This article is based on Spark 3.3.0.
Let's explore the logic of Spark whole-stage code generation (Codegen) through a unit test:
test("SortAggregate should be included in WholeStageCodegen") {
val df = spark.range(10).agg(max(col("id")), avg(col("id")))
withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
assert(df.collect() === Array(Row(9, 4.5)))
}
}
The whole-stage-codegen part of the first stage of the physical plan produced for this query is:
WholeStageCodegen
+- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
   +- *(1) Range (0, 10, step=1, splits=2)
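If you want to see the Java that whole-stage codegen actually emits for this plan, Spark ships a debug helper for exactly that. A minimal sketch (the import and call are the stock Spark debug API; spark is assumed to be an active SparkSession):

import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.functions.{avg, col, max}

val df = spark.range(10).agg(max(col("id")), avg(col("id")))
// prints every WholeStageCodegen subtree together with its generated Java source
df.debugCodegen()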
Code generation for this first stage involves the produce and consume methods of SortAggregateExec and RangeExec; we will analyze them one by one. The data flow of the first-stage WholeStageCodegen is as follows:
WholeStageCodegenExec       SortAggregateExec(partial)           RangeExec
===========================================================================
-> execute()
     |
  doExecute() --------->  inputRDDs() ---------------------->  inputRDDs()
     |
  doCodeGen()
     |
     +----------------->  produce()
                             |
                          doProduce()
                             |
                          doProduceWithoutKeys() ----------->  produce()
                                                                 |
                                                              doProduce()
                             |
                          doConsume() <---------------------  consume()
                             |
                          doConsumeWithoutKeys()
     |
  doConsume() <---------  consume()

(Note: the final consume() back into WholeStageCodegenExec is not issued from doConsumeWithoutKeys; it is doProduceWithoutKeys that drives it.)
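For completeness, the recursion shown above is kicked off by WholeStageCodegenExec.doCodeGen, which passes itself to its child as the parent (abbreviated excerpt from Spark 3.3.0; the elided part splices the returned code into a generated BufferedRowIterator subclass and compiles it with Janino):

def doCodeGen(): (CodegenContext, CodeAndComment) = {
  val ctx = new CodegenContext
  // the produce/consume recursion starts here: the child is SortAggregateExec,
  // and `this` (the WholeStageCodegenExec) becomes its parent
  val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
  // ... template assembly and cleanup elided ...
}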
Having seen the call flow, let's walk through the code. Everything starts from CodegenSupport.produce, the entry point shared by every codegen-capable operator:

final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
  this.parent = parent
  ctx.freshNamePrefix = variablePrefix
  s"""
     |${ctx.registerComment(s"PRODUCE: ${this.simpleString(conf.maxToStringFields)}")}
     |${doProduce(ctx)}
   """.stripMargin
}
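Two details here are worth noting before moving on. ctx.freshNamePrefix = variablePrefix is what stamps every fresh name this operator creates, which is why the generated names we will see below all look like range_.... And registerComment only materializes its PRODUCE banner when codegen comments are enabled (the spark.sql.codegen.comments flag); the exact rendering below is an assumption from memory, shown purely for illustration:

// Illustrative shape of produce()'s output in the generated Java,
// assuming spark.sql.codegen.comments is enabled:
// PRODUCE: Range (0, 10, step=1, splits=2)
/* ... everything returned by RangeExec.doProduce(ctx) follows here, with
   fresh variable names prefixed by "range_" ... */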
RangeExec.doProduce then generates the batched loop that actually produces the range values (Spark 3.3.0):

protected override def doProduce(ctx: CodegenContext): String = {
  val numOutput = metricTerm(ctx, "numOutputRows")

  val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
  val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")

  val value = ctx.freshName("value")
  val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
  val BigInt = classOf[java.math.BigInteger].getName

  // Inline mutable state since not many Range operations in a task
  val taskContext = ctx.addMutableState("TaskContext", "taskContext",
    v => s"$v = TaskContext.get();", forceInline = true)
  val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
    v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true)

  // In order to periodically update the metrics without inflicting performance penalty, this
  // operator produces elements in batches. After a batch is complete, the metrics are updated
  // and a new batch is started.
  // In the implementation below, the code in the inner loop is producing all the values
  // within a batch, while the code in the outer loop is setting batch parameters and updating
  // the metrics.

  // Once nextIndex == batchEnd, it's time to progress to the next batch.
  val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")

  // How many values should still be generated by this range operator.
  val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")

  // How many values should be generated in the next batch.
  val nextBatchTodo = ctx.freshName("nextBatchTodo")

  // The default size of a batch, which must be positive integer
  val batchSize = 1000

  val initRangeFuncName = ctx.addNewFunction("initRange",
    s"""
      | private void initRange(int idx) {
      |   $BigInt index = $BigInt.valueOf(idx);
      |   $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
      |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
      |   $BigInt step = $BigInt.valueOf(${step}L);
      |   $BigInt start = $BigInt.valueOf(${start}L);
      |   long partitionEnd;
      |
      |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
      |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
      |     $nextIndex = Long.MAX_VALUE;
      |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
      |     $nextIndex = Long.MIN_VALUE;
      |   } else {
      |     $nextIndex = st.longValue();
      |   }
      |   $batchEnd = $nextIndex;
      |
      |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
      |     .multiply(step).add(start);
      |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
      |     partitionEnd = Long.MAX_VALUE;
      |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
      |     partitionEnd = Long.MIN_VALUE;
      |   } else {
      |     partitionEnd = end.longValue();
      |   }
      |
      |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
      |     $BigInt.valueOf($nextIndex));
      |   $numElementsTodo = startToEnd.divide(step).longValue();
      |   if ($numElementsTodo < 0) {
      |     $numElementsTodo = 0;
      |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
      |     $numElementsTodo++;
      |   }
      | }
     """.stripMargin)

  val localIdx = ctx.freshName("localIdx")
  val localEnd = ctx.freshName("localEnd")

  val stopCheck = if (parent.needStopCheck) {
    s"""
      |if (shouldStop()) {
      |  $nextIndex = $value + ${step}L;
      |  $numOutput.add($localIdx + 1);
      |  $inputMetrics.incRecordsRead($localIdx + 1);
      |  return;
      |}
     """.stripMargin
  } else {
    "// shouldStop check is eliminated"
  }

  val loopCondition = if (limitNotReachedChecks.isEmpty) {
    "true"
  } else {
    limitNotReachedChecks.mkString(" && ")
  }

  s"""
    | // initialize Range
    | if (!$initTerm) {
    |   $initTerm = true;
    |   $initRangeFuncName(partitionIndex);
    | }
    |
    | while ($loopCondition) {
    |   if ($nextIndex == $batchEnd) {
    |     long $nextBatchTodo;
    |     if ($numElementsTodo > ${batchSize}L) {
    |       $nextBatchTodo = ${batchSize}L;
    |       $numElementsTodo -= ${batchSize}L;
    |     } else {
    |       $nextBatchTodo = $numElementsTodo;
    |       $numElementsTodo = 0;
    |       if ($nextBatchTodo == 0) break;
    |     }
    |     $batchEnd += $nextBatchTodo * ${step}L;
    |   }
    |
    |   int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
    |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
    |     long $value = ((long)$localIdx * ${step}L) + $nextIndex;
    |     ${consume(ctx, Seq(ev))}
    |     $stopCheck
    |   }
    |   $nextIndex = $batchEnd;
    |   $numOutput.add($localEnd);
    |   $inputMetrics.incRecordsRead($localEnd);
    |   $taskContext.killTaskIfInterrupted();
    | }
   """.stripMargin
}
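Before dissecting each piece of this method, it helps to see roughly what it expands to in the generated class for our Range (0, 10, step=1, splits=2). The sketch below is reconstructed by hand: the name suffixes, the metric reference index, and the inlined SortAggregate consume code are all approximate. Because the partial aggregate parent is a blocking operator, loopCondition is simply true and the shouldStop check is eliminated:

// initialize Range
if (!range_initRange_0) {
  range_initRange_0 = true;
  initRange(partitionIndex);
}

while (true) {
  if (range_nextIndex_0 == range_batchEnd_0) {
    long range_nextBatchTodo_0;
    if (range_numElementsTodo_0 > 1000L) {      // batchSize == 1000
      range_nextBatchTodo_0 = 1000L;
      range_numElementsTodo_0 -= 1000L;
    } else {
      range_nextBatchTodo_0 = range_numElementsTodo_0;
      range_numElementsTodo_0 = 0;
      if (range_nextBatchTodo_0 == 0) break;
    }
    range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
  }

  int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
  for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
    long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
    // ... consume(ctx, Seq(ev)): the parent SortAggregate's update code is inlined here ...
    // shouldStop check is eliminated
  }
  range_nextIndex_0 = range_batchEnd_0;
  ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
  range_inputMetrics_0.incRecordsRead(range_localEnd_0);
  range_taskContext_0.killTaskIfInterrupted();
}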
val numOutput = metricTerm(ctx, "numOutputRows")
The numOutput metric records how many rows this operator outputs.
val initTerm = ... and val nextIndex = ...
initTerm flags whether the range for the current partition has been initialized, i.e., whether the generated initRange function has already run; nextIndex is the logical index RangeExec advances as it iterates over the data it generates. Both become member variables of the generated class, i.e., global state.
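As a quick sketch of what addMutableState does with them: each call registers a field on the generated class (the _0 suffixes are assigned by CodegenContext, and with only a few states they remain individual fields rather than being compacted into arrays), roughly:

// illustrative field declarations in the generated class (names approximate)
private boolean range_initRange_0;   // initTerm
private long range_nextIndex_0;      // nextIndex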
val value = ... and val ev = ...
ev stands for the value RangeExec produces and is ultimately what gets passed to the parent via consume(ctx, Seq(ev)). Its underlying variable value is assigned by the generated line long $value = ((long)$localIdx * ${step}L) + $nextIndex;, which is what the parent node then consumes.
val taskContext = ... and val inputMetrics = ...
taskContext and inputMetrics are likewise global (member) variables, and both come with initialization code. Initialization of this kind is emitted into the generated class's init method, yielding code such as:

range_taskContext_0 = TaskContext.get();
range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();

It ends up in init because each such initializer is appended to mutableStateInitCode, an ArrayBuffer inside CodegenContext; WholeStageCodegenExec later stitches that buffer together via ctx.initMutableStates() when it emits the following template:

public void init(int index, scala.collection.Iterator[] inputs) {
  partitionIndex = index;
  this.inputs = inputs;
  ${ctx.initMutableStates()}
  ${ctx.initPartition()}
}
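For this query the template renders to something like the following (an illustrative reconstruction; suffixes and ordering vary by run):

public void init(int index, scala.collection.Iterator[] inputs) {
  partitionIndex = index;
  this.inputs = inputs;
  // spliced in by ctx.initMutableStates():
  range_taskContext_0 = TaskContext.get();
  range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
  // ... initializers registered by the other operators in this stage ...
}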
val batchEnd = ... and val numElementsTodo = ...
These two are also member variables of the generated class (global state): batchEnd marks where the current batch ends, and numElementsTodo tracks how many values this partition still has to produce.
val nextBatchTodo = ...
This one is a temporary local variable, used inside the loop to size the next batch.
val initRangeFuncName = ...
This is RangeExec's own logic for computing the sub-range the current partition is responsible for; every physical operator's equivalent is different, so we won't dwell on its internals here.
The final while ($loopCondition) loop
This is the part that produces a different slice of the data for each partition, driven by the partition's index; the sketch below makes the slicing concrete.
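Here is a minimal Scala sketch of the arithmetic initRange performs, evaluated for our Range (0, 10, step=1, splits=2); the slice helper is hypothetical, written only to mirror the generated BigInteger code:

import java.math.BigInteger

// hypothetical helper mirroring initRange's start/end computation per partition
def slice(idx: Long, numSlices: Long = 2L, numElements: Long = 10L,
          step: Long = 1L, start: Long = 0L): (Long, Long) = {
  val index = BigInteger.valueOf(idx)
  val st = index.multiply(BigInteger.valueOf(numElements))
    .divide(BigInteger.valueOf(numSlices))
    .multiply(BigInteger.valueOf(step))
    .add(BigInteger.valueOf(start))
  val end = index.add(BigInteger.ONE)
    .multiply(BigInteger.valueOf(numElements))
    .divide(BigInteger.valueOf(numSlices))
    .multiply(BigInteger.valueOf(step))
    .add(BigInteger.valueOf(start))
  (st.longValueExact(), end.longValueExact())
}

slice(0)  // (0, 5)  -> partition 0 produces 0, 1, 2, 3, 4
slice(1)  // (5, 10) -> partition 1 produces 5, 6, 7, 8, 9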
One thing worth calling out in initRangeFuncName(partitionIndex) is the partitionIndex variable: it is a field of BufferedRowIterator, the superclass of the generated class, and it too is assigned in the init method:

public void init(int index, scala.collection.Iterator[] inputs) {
  partitionIndex = index;
  this.inputs = inputs;
  ...
}
consume(ctx, Seq(ev))
This is the hook through which the parent operator consumes the data RangeExec produces; we will continue with it in the next part.
numOutput, inputMetrics and taskContext
numOutput adds the number of rows just output; inputMetrics increments the records-read counter at the task-metrics level; taskContext.killTaskIfInterrupted() checks whether the current task has been killed and, if so, throws an exception immediately.