Spark SQL sampling functions: a pitfall of TABLESAMPLE

I recently needed to implement a piece of Spark SQL logic that samples a specified number of rows from a dataset.

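Because the dataset is large, my first approach was a window function: sort the rows randomly, then keep the first n by row_number. That ran slowly, so I remembered TABLESAMPLE, which can sample by ROWS directly. When I tried it, it was remarkably fast: sampling a dataset of hundreds of millions of rows finished in a few seconds. That made me curious enough to check the documentation and the code.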
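To make the cost of the window-function approach concrete, here is a minimal plain-Python sketch of what "random sort, then keep row_number <= n" amounts to. It is a simulation, not Spark code, and the function name `sample_by_random_sort` and its `seed` parameter are hypothetical. Every row gets a random key and the whole dataset is sorted, which is the part that becomes expensive on large inputs.

```python
import random

def sample_by_random_sort(rows, n, seed=None):
    """Simulate ORDER BY rand() + row_number() <= n: tag every row with a
    random key, sort the whole dataset by that key, keep the first n rows."""
    rng = random.Random(seed)
    keyed = [(rng.random(), row) for row in rows]   # one random key per row
    keyed.sort(key=lambda pair: pair[0])            # full sort of all rows
    return [row for _, row in keyed[:n]]            # row_number <= n

rows = list(range(1000))
picked = sample_by_random_sort(rows, 10, seed=42)
print(len(picked))  # 10
```

The result is an exact-size random sample, but the work is proportional to sorting the full dataset no matter how small n is.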

The TABLESAMPLE statement is used to sample the table. It supports the following sampling methods:

  • TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
  • TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages are defined as a number between 0 and 100.
  • TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a x out of y fraction.

Note: TABLESAMPLE returns the approximate number of rows or fraction requested.
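As a rough illustration of the PERCENT and BUCKET forms, the sketch below converts both into a fraction between 0 and 1, which is the scale the quoted source code at the end of this post works with. It is plain Python rather than Spark code; the helper names are hypothetical and the tolerance `eps=1e-6` is an assumed value, not taken from the documentation.

```python
def percent_to_fraction(percent):
    """TABLESAMPLE(x PERCENT): x is on a 0-100 scale, the sampler wants 0-1."""
    return percent / 100.0

def bucket_to_fraction(x, y):
    """TABLESAMPLE(BUCKET x OUT OF y): keep an x/y fraction of the rows."""
    return x / y

def check_fraction(fraction, eps=1e-6):
    """Reject fractions outside [0, 1]; eps is an assumed rounding tolerance."""
    if not (0.0 - eps <= fraction <= 1.0 + eps):
        raise ValueError(f"Sampling fraction ({fraction}) must be on interval [0, 1]")
    return fraction

print(check_fraction(percent_to_fraction(5)))    # 0.05
print(check_fraction(bucket_to_fraction(1, 4)))  # 0.25
```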

The documentation says little about how this is implemented, so I went to the source code for answers.

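In the source, when SampleByRowsContext is matched, the method called is Limit(expression(ctx.expression), query). In other words, TABLESAMPLE(x ROWS) follows the same logic as LIMIT x, which also explains why it is so fast.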
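The consequence of that Limit call can be sketched in plain Python. This is a simulation under a simplifying assumption: the rows arrive in one fixed order, whereas in Spark the rows read first depend on partitioning and file layout, which the sketch does not model. The function name `sample_rows_like_limit` is hypothetical.

```python
from itertools import islice

def sample_rows_like_limit(rows, n):
    """Mimic a LIMIT-style 'sample': stop after the first n rows."""
    return list(islice(rows, n))

# Hypothetical data that happens to be stored in sorted order.
rows = range(1_000_000)
first = sample_rows_like_limit(rows, 5)
second = sample_rows_like_limit(rows, 5)
print(first)            # [0, 1, 2, 3, 4]
print(first == second)  # True: nothing random about it
```

Only n rows are ever touched, which is why it finishes in seconds, but the result says nothing about the other 999,995 rows.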

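Actual random sampling only happens in the SampleByPercentileContext branch (and, judging by the code below, the plain SampleByBucketContext branch), both of which go through sample(...).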
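For contrast, here is a minimal sketch of fraction-based sampling, assuming each row is kept independently with probability equal to the fraction (sampling without replacement). It is plain Python, not Spark's implementation; `sample_by_fraction` is a hypothetical name, and `seed` plays the role of the REPEATABLE seed mentioned in the source comments.

```python
import random

def sample_by_fraction(rows, fraction, seed=None):
    """Keep each row independently with probability `fraction`."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fraction]

rows = range(100_000)
picked = sample_by_fraction(rows, 0.01, seed=7)
# The expected size is about 1000, but the exact count varies with the seed.
print(len(picked))
```

Rows are drawn from the whole dataset, and the output size is only approximately fraction times the input size, which fits the "approximate number of rows or fraction" note in the documentation. The price is that every row has to be visited.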

So if the randomness of the sample matters, play it safe and use the percentage form (SampleByPercentileContext), or fall back to the window function.

Appendix: the relevant code:

    /**
     * Add a [[Sample]] to a logical plan.
     *
     * This currently supports the following sampling methods:
     * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
     * - TABLESAMPLE(x PERCENT) [REPEATABLE (y)]: Sample the table down to the given percentage with
     * seed 'y'. Note that percentages are defined as a number between 0 and 100.
     * - TABLESAMPLE(BUCKET x OUT OF y) [REPEATABLE (z)]: Sample the table down to a 'x' divided by
     * 'y' fraction with seed 'z'.
     */
    private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
      // Create a sampled plan if we need one.
      def sample(fraction: Double, seed: Long): Sample = {
        // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
        // function takes X PERCENT as the input and the range of X is [0, 100], we need to
        // adjust the fraction.
        val eps = RandomSampler.roundingEpsilon
        validate(fraction >= 0.0 - eps && fraction ...

        // ... (the excerpt is truncated here; per the text above, the next line
        //      belongs to the SampleByRowsContext case) ...

          Limit(expression(ctx.expression), query)

        case ctx: SampleByPercentileContext =>
          val fraction = ctx.percentage.getText.toDouble
          val sign = if (ctx.negativeSign == null) 1 else -1
          sample(sign * fraction / 100.0d, seed)

        case ctx: SampleByBytesContext =>
          val bytesStr = ctx.bytes.getText
          if (bytesStr.matches("[0-9]+[bBkKmMgG]")) {
            throw QueryParsingErrors.tableSampleByBytesUnsupportedError("byteLengthLiteral", ctx)
          } else {
            throw QueryParsingErrors.invalidByteLengthLiteralError(bytesStr, ctx)
          }

        case ctx: SampleByBucketContext if ctx.ON() != null =>
          if (ctx.identifier != null) {
            throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
              "BUCKET x OUT OF y ON colname", ctx)
          } else {
            throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
              "BUCKET x OUT OF y ON function", ctx)
          }

        case ctx: SampleByBucketContext =>
          sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble, seed)
      }
    }