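I recently needed to implement some Spark SQL logic that samples a specified number of rows from a dataset.
Because the dataset is large, my first approach was to use a window function: order the rows randomly, assign row_number, and keep the first n rows. That ran slowly, so I remembered TABLESAMPLE, which can take a row count directly. When I tried it, it was extremely fast and sampled a table with hundreds of millions of rows within a few seconds. That made me curious, so I went through the documentation and the code.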
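For reference, the two approaches can be sketched as a pair of query builders. This is only a minimal sketch: the object name SamplingQueries, the helper names, the alias ranked, and the table name events are all hypothetical, and the window query is my reconstruction of the approach described above rather than the exact query I ran.

```scala
// Hypothetical helpers that only build SQL text; nothing here talks to Spark.
object SamplingQueries {
  // Approach 1: give every row a random sort key, number the rows in that
  // order, and keep the first n. The row count is exact and the choice is
  // random, but the whole table has to be sorted.
  def windowSample(table: String, n: Long): String =
    s"""SELECT *
       |FROM (
       |  SELECT t.*, row_number() OVER (ORDER BY rand()) AS rn
       |  FROM $table t
       |) ranked
       |WHERE rn <= $n""".stripMargin

  // Approach 2: TABLESAMPLE with a row count.
  def tableSampleRows(table: String, n: Long): String =
    s"SELECT * FROM $table TABLESAMPLE ($n ROWS)"
}
```

With a SparkSession named spark in scope, either string could be passed to spark.sql(...), for example spark.sql(SamplingQueries.tableSampleRows("events", 1000L)).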
From the Spark SQL documentation:
The TABLESAMPLE statement is used to sample the table. It supports the following sampling methods:
- TABLESAMPLE (x ROWS): Sample the table down to the given number of rows.
- TABLESAMPLE (x PERCENT): Sample the table down to the given percentage. Note that percentages are defined as a number between 0 and 100.
- TABLESAMPLE (BUCKET x OUT OF y): Sample the table down to a x out of y fraction.
Note: TABLESAMPLE returns the approximate number of rows or fraction requested.
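The documentation says little about how this is implemented, so I went looking in the source code.
In the source, when SampleByRowsContext is matched, the parser calls Limit(expression(ctx.expression), query). In other words, TABLESAMPLE (n ROWS) follows exactly the same logic as LIMIT n.
The SampleByPercentileContext branch is the one that performs real random sampling (in the code below, SampleByBucketContext goes through the same sample helper).
So if the randomness of the sample matters, stick with TABLESAMPLE (x PERCENT), which is the SampleByPercentileContext branch, or with the window function.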
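To make the difference concrete, here is a small plain-Scala sketch that runs without Spark. The object LimitVsSample and both helpers are hypothetical names of mine, and bernoulliSample is a simplified stand-in for fraction-based sampling (each row kept independently with a given probability), not Spark's actual sampler class.

```scala
import scala.util.Random

object LimitVsSample {
  // What TABLESAMPLE (n ROWS) amounts to according to the parser code: a LIMIT.
  // It keeps whichever n rows arrive first, so no randomness is involved.
  def limitRows[A](rows: Iterator[A], n: Int): Iterator[A] = rows.take(n)

  // Simplified fraction-based sampling (an illustration, not Spark's class):
  // every row is kept independently with probability `fraction`, so the
  // output size is only approximately fraction * input size.
  def bernoulliSample[A](rows: Iterator[A], fraction: Double, seed: Long): Iterator[A] = {
    require(fraction >= 0.0 && fraction <= 1.0, s"fraction ($fraction) must be in [0, 1]")
    val rng = new Random(seed)
    rows.filter(_ => rng.nextDouble() < fraction)
  }
}
```

On an in-memory iterator, limitRows returns the leading rows every time, while bernoulliSample returns rows spread across the whole input and repeats the same selection only when given the same seed. That matches the two observations above: the ROWS form is fast because it can stop early, and the PERCENT form only promises an approximate size.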
Appendix: the relevant code:
/**
 * Add a [[Sample]] to a logical plan.
 *
 * This currently supports the following sampling methods:
 * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
 * - TABLESAMPLE(x PERCENT) [REPEATABLE (y)]: Sample the table down to the given percentage with
 * seed 'y'. Note that percentages are defined as a number between 0 and 100.
 * - TABLESAMPLE(BUCKET x OUT OF y) [REPEATABLE (z)]: Sample the table down to a 'x' divided by
 * 'y' fraction with seed 'z'.
 */
private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
  // Create a sampled plan if we need one.
  def sample(fraction: Double, seed: Long): Sample = {
    // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
    // function takes X PERCENT as the input and the range of X is [0, 100], we need to
    // adjust the fraction.
    val eps = RandomSampler.roundingEpsilon
    validate(fraction >= 0.0 - eps && fraction ...
    // ... (the rest of `sample`, the seed setup, and the start of the match are missing here)
    case ctx: SampleByRowsContext =>
      Limit(expression(ctx.expression), query)
    case ctx: SampleByPercentileContext =>
      val fraction = ctx.percentage.getText.toDouble
      val sign = if (ctx.negativeSign == null) 1 else -1
      sample(sign * fraction / 100.0d, seed)
    case ctx: SampleByBytesContext =>
      val bytesStr = ctx.bytes.getText
      if (bytesStr.matches("[0-9]+[bBkKmMgG]")) {
        throw QueryParsingErrors.tableSampleByBytesUnsupportedError("byteLengthLiteral", ctx)
      } else {
        throw QueryParsingErrors.invalidByteLengthLiteralError(bytesStr, ctx)
      }
    case ctx: SampleByBucketContext if ctx.ON() != null =>
      if (ctx.identifier != null) {
        throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
          "BUCKET x OUT OF y ON colname", ctx)
      } else {
        throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
          "BUCKET x OUT OF y ON function", ctx)
      }
    case ctx: SampleByBucketContext =>
      sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble, seed)
  }
}