This article is based on Spark 3.3.0.
We explore the logic of Spark's whole-stage code generation (codegen) starting from a unit test:
test("SortAggregate should be included in WholeStageCodegen") {
val df = spark.range(10).agg(max(col("id")), avg(col("id")))
withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
assert(df.collect() === Array(Row(9, 4.5)))
}
}
The whole-stage-codegen part of the second stage of the physical plan produced by this query is shown below (the test-only conf spark.sql.test.forceApplySortAggregate forces the planner to choose SortAggregateExec over HashAggregateExec):
WholeStageCodegen (2)
+- *(2) SortAggregate(key=[], functions=[max(id#0L), avg(id#0L)], output=[max(id)#5L, avg(id)#6])
   +- InputAdapter
      +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#13]
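This plan, and the generated source analyzed below, can be reproduced interactively. A minimal sketch, assuming a local SparkSession named spark (note that spark.sql.test.forceApplySortAggregate is an internal, test-only conf):

import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.functions._

// Force the planner to pick SortAggregateExec instead of HashAggregateExec.
spark.conf.set("spark.sql.test.forceApplySortAggregate", "true")
val df = spark.range(10).agg(max(col("id")), avg(col("id")))
df.debugCodegen()        // dumps the generated Java for each WholeStageCodegen subtree
// df.explain("codegen") // alternative: Dataset.explain with the "codegen" mode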
Code generation for the second stage involves the produce and consume methods of SortAggregateExec and InputAdapter, with ShuffleExchangeExec supplying the input RDD. Let's analyze them one by one.
The data flow through the second-stage WholeStageCodegen is as follows:
WholeStageCodegenExec      SortAggregateExec(Final)       InputAdapter         ShuffleExchangeExec
==================================================================================================
 -> execute()
       |
   doExecute() ----------> inputRDDs() -----------------> inputRDDs() -------> execute()
       |                                                                          |
   doCodeGen()                                                                doExecute()
       |                                                                          |
       +-----------------> produce()                                        ShuffledRowRDD
                               |
                           doProduce()
                               |
                           doProduceWithoutKeys() ------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume() <------------------ consume()
                               |
                           doConsumeWithoutKeys()
                               |  (consume here is invoked by doProduceWithoutKeys,
                               |   not by doConsumeWithoutKeys)
   doConsume() <---------- consume()
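On the driver, this whole flow is stitched together by WholeStageCodegenExec.doExecute. A condensed sketch of its single-input path (paraphrased from Spark 3.3.0; compilation fallback and the two-input case are omitted):

// Condensed sketch, not the verbatim Spark source.
override def doExecute(): RDD[InternalRow] = {
  val (ctx, cleanedSource) = doCodeGen()                 // runs child.produce() recursively
  val (clazz, _) = CodeGenerator.compile(cleanedSource)  // compile the generated class
  val references = ctx.references.toArray
  val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
  rdds.head.mapPartitionsWithIndex { (index, iter) =>
    val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
    buffer.init(index, Array(iter))                      // becomes inputs[0] in init()
    new Iterator[InternalRow] {
      override def hasNext: Boolean = buffer.hasNext     // drives processNext()
      override def next(): InternalRow = buffer.next()
    }
  }
}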
SortAggregateExec's doConsume method ultimately calls doConsumeWithoutKeys(ctx, input), where input is

ArrayBuffer(ExprCode(,sortAgg_exprIsNull_0_0,sortAgg_expr_0_0),
            ExprCode(,sortAgg_exprIsNull_1_0,sortAgg_expr_1_0),
            ExprCode(,sortAgg_exprIsNull_2_0,sortAgg_expr_2_0))

These three ExprCode entries correspond to the columns of the shuffled partial-aggregation buffer: max(id), avg's sum, and avg's count.
The difference from SortAggregateExec(Partial) lies in which update expressions are selected:
val updateExprs = aggregateExpressions.map { e =>
  e.mode match {
    case Partial | Complete =>
      e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
    case PartialMerge | Final =>
      e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
  }
}
Because this is the Final stage, the update expressions are the mergeExpressions, i.e. for Average:

protected def getMergeExpressions = Seq(
  /* sum = */ Add(sum.left, sum.right, useAnsiAdd),
  /* count = */ count.left + count.right
)
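For comparison, the Partial stage uses updateExpressions. For Average they look roughly like the following (paraphrased from Spark 3.3.0's Average; the exact helper shapes are approximate), which is where the literal + 1 seen in the first stage's generated code comes from:

protected def getUpdateExpressions: Seq[Expression] = Seq(
  /* sum = */ Add(
    sum,
    coalesce(child.cast(sumDataType), Literal.default(sumDataType)),
    useAnsiAdd),
  /* count = */ If(child.isNull, count, count + 1L)  // count rows, skipping nulls
)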
The corresponding generated ExprCode instances are:
List(ExprCode(boolean sortAgg_isNull_13 = true;
              double sortAgg_value_13 = -1.0;
              if (!sortAgg_bufIsNull_1) {
                if (!sortAgg_exprIsNull_1_0) {
                  sortAgg_isNull_13 = false; // resultCode could change nullability.
                  sortAgg_value_13 = sortAgg_bufValue_1 + sortAgg_expr_1_0;
                }
              },sortAgg_isNull_13,sortAgg_value_13),
     ExprCode(boolean sortAgg_isNull_16 = true;
              long sortAgg_value_16 = -1L;
              if (!sortAgg_bufIsNull_2) {
                if (!sortAgg_exprIsNull_2_0) {
                  sortAgg_isNull_16 = false; // resultCode could change nullability.
                  sortAgg_value_16 = sortAgg_bufValue_2 + sortAgg_expr_2_0;
                }
              },sortAgg_isNull_16,sortAgg_value_16))
Note that for sortAgg_value_16, i.e. the COUNT buffer, the update here is + sortAgg_expr_2_0, whereas in the Partial stage it is + 1: each task has already counted its local rows, so the Final stage merges the per-task result buffers instead of counting rows again.
The merged values sortAgg_value_13 and sortAgg_value_16 (together with their null flags) are then assigned back to the global buffer variables:
sortAgg_bufIsNull_1 = sortAgg_isNull_13;
sortAgg_bufValue_1 = sortAgg_value_13;
sortAgg_bufIsNull_2 = sortAgg_isNull_16;
sortAgg_bufValue_2 = sortAgg_value_16;
Then consume(ctx, resultVars) is invoked, where resultVars is ExprCode(,sortAgg_isNull_4,sortAgg_value_4), which contains the final AVG result value.
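That final value is produced by Average's evaluateExpression, which for a non-decimal input is essentially the sum buffer divided by the count buffer (a sketch of the default branch, paraphrased from Spark 3.3.0):

// Divide with failOnError = false yields null on division by zero,
// which matches the sortAgg_value_6 == 0 check in processNext below.
override lazy val evaluateExpression: Expression =
  Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)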
The rest of the data flow is the same as before: the final result row is assembled with an UnsafeRowWriter:
sortAgg_mutableStateArray_0[0].reset();
sortAgg_mutableStateArray_0[0].zeroOutNullBytes();
if (sortAgg_bufIsNull_0) {
  sortAgg_mutableStateArray_0[0].setNullAt(0);
} else {
  sortAgg_mutableStateArray_0[0].write(0, sortAgg_bufValue_0);
}
if (sortAgg_isNull_4) {
  sortAgg_mutableStateArray_0[0].setNullAt(1);
} else {
  sortAgg_mutableStateArray_0[0].write(1, sortAgg_value_4);
}
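Here sortAgg_mutableStateArray_0[0] is simply an UnsafeRowWriter(2, 0): two output fields and no variable-length section. A standalone sketch of the same call sequence (the literal values are the query's expected result, used purely for illustration):

import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

val writer = new UnsafeRowWriter(2, 0) // 2 fields, no variable-length buffer
writer.reset()                         // rewind the writer to the row start
writer.zeroOutNullBytes()              // clear the null-tracking bitset
writer.write(0, 9L)                    // max(id)
writer.write(1, 4.5d)                  // avg(id)
val row = writer.getRow                // the UnsafeRow passed to append(...)
assert(row.getLong(0) == 9L && row.getDouble(1) == 4.5)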
Finally the row is emitted:

append((sortAgg_mutableStateArray_0[0].getRow()));

The data flow is consistent with the first stage. The complete generated code for this stage is as follows:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private boolean sortAgg_initAgg_0;
/* 010 */ private boolean sortAgg_bufIsNull_0;
/* 011 */ private long sortAgg_bufValue_0;
/* 012 */ private boolean sortAgg_bufIsNull_1;
/* 013 */ private double sortAgg_bufValue_1;
/* 014 */ private boolean sortAgg_bufIsNull_2;
/* 015 */ private long sortAgg_bufValue_2;
/* 016 */ private scala.collection.Iterator inputadapter_input_0;
/* 017 */ private boolean sortAgg_sortAgg_isNull_10_0;
/* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] sortAgg_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
/* 019 */
/* 020 */ public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 021 */ this.references = references;
/* 022 */ }
/* 023 */
/* 024 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */ partitionIndex = index;
/* 026 */ this.inputs = inputs;
/* 027 */
/* 028 */ inputadapter_input_0 = inputs[0];
/* 029 */ sortAgg_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 030 */
/* 031 */ }
/* 032 */
/* 033 */ private void sortAgg_doAggregate_max_0(long sortAgg_expr_0_0, boolean sortAgg_exprIsNull_0_0) throws java.io.IOException {
/* 034 */ sortAgg_sortAgg_isNull_10_0 = true;
/* 035 */ long sortAgg_value_10 = -1L;
/* 036 */
/* 037 */ if (!sortAgg_bufIsNull_0 && (sortAgg_sortAgg_isNull_10_0 ||
/* 038 */ sortAgg_bufValue_0 > sortAgg_value_10)) {
/* 039 */ sortAgg_sortAgg_isNull_10_0 = false;
/* 040 */ sortAgg_value_10 = sortAgg_bufValue_0;
/* 041 */ }
/* 042 */
/* 043 */ if (!sortAgg_exprIsNull_0_0 && (sortAgg_sortAgg_isNull_10_0 ||
/* 044 */ sortAgg_expr_0_0 > sortAgg_value_10)) {
/* 045 */ sortAgg_sortAgg_isNull_10_0 = false;
/* 046 */ sortAgg_value_10 = sortAgg_expr_0_0;
/* 047 */ }
/* 048 */
/* 049 */ sortAgg_bufIsNull_0 = sortAgg_sortAgg_isNull_10_0;
/* 050 */ sortAgg_bufValue_0 = sortAgg_value_10;
/* 051 */ }
/* 052 */
/* 053 */ private void sortAgg_doAggregateWithoutKey_0() throws java.io.IOException {
/* 054 */ // initialize aggregation buffer
/* 055 */ sortAgg_bufIsNull_0 = true;
/* 056 */ sortAgg_bufValue_0 = -1L;
/* 057 */ sortAgg_bufIsNull_1 = false;
/* 058 */ sortAgg_bufValue_1 = 0.0D;
/* 059 */ sortAgg_bufIsNull_2 = false;
/* 060 */ sortAgg_bufValue_2 = 0L;
/* 061 */
/* 062 */ while ( inputadapter_input_0.hasNext()) {
/* 063 */ InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 064 */
/* 065 */ boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
/* 066 */ long inputadapter_value_0 = inputadapter_isNull_0 ?
/* 067 */ -1L : (inputadapter_row_0.getLong(0));
/* 068 */ boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1);
/* 069 */ double inputadapter_value_1 = inputadapter_isNull_1 ?
/* 070 */ -1.0 : (inputadapter_row_0.getDouble(1));
/* 071 */ boolean inputadapter_isNull_2 = inputadapter_row_0.isNullAt(2);
/* 072 */ long inputadapter_value_2 = inputadapter_isNull_2 ?
/* 073 */ -1L : (inputadapter_row_0.getLong(2));
/* 074 */
/* 075 */ sortAgg_doConsume_0(inputadapter_row_0, inputadapter_value_0, inputadapter_isNull_0, inputadapter_value_1, inputadapter_isNull_1, inputadapter_value_2, inputadapter_isNull_2);
/* 076 */ // shouldStop check is eliminated
/* 077 */ }
/* 078 */
/* 079 */ }
/* 080 */
/* 081 */ protected void processNext() throws java.io.IOException {
/* 082 */ while (!sortAgg_initAgg_0) {
/* 083 */ sortAgg_initAgg_0 = true;
/* 084 */ sortAgg_doAggregateWithoutKey_0();
/* 085 */
/* 086 */ // output the result
/* 087 */ boolean sortAgg_isNull_6 = sortAgg_bufIsNull_2;
/* 088 */ double sortAgg_value_6 = -1.0;
/* 089 */ if (!sortAgg_bufIsNull_2) {
/* 090 */ sortAgg_value_6 = (double) sortAgg_bufValue_2;
/* 091 */ }
/* 092 */ boolean sortAgg_isNull_4 = false;
/* 093 */ double sortAgg_value_4 = -1.0;
/* 094 */ if (sortAgg_isNull_6 || sortAgg_value_6 == 0) {
/* 095 */ sortAgg_isNull_4 = true;
/* 096 */ } else {
/* 097 */ if (sortAgg_bufIsNull_1) {
/* 098 */ sortAgg_isNull_4 = true;
/* 099 */ } else {
/* 100 */ sortAgg_value_4 = (double)(sortAgg_bufValue_1 / sortAgg_value_6);
/* 101 */ }
/* 102 */ }
/* 103 */
/* 104 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 105 */ sortAgg_mutableStateArray_0[0].reset();
/* 106 */
/* 107 */ sortAgg_mutableStateArray_0[0].zeroOutNullBytes();
/* 108 */
/* 109 */ if (sortAgg_bufIsNull_0) {
/* 110 */ sortAgg_mutableStateArray_0[0].setNullAt(0);
/* 111 */ } else {
/* 112 */ sortAgg_mutableStateArray_0[0].write(0, sortAgg_bufValue_0);
/* 113 */ }
/* 114 */
/* 115 */ if (sortAgg_isNull_4) {
/* 116 */ sortAgg_mutableStateArray_0[0].setNullAt(1);
/* 117 */ } else {
/* 118 */ sortAgg_mutableStateArray_0[0].write(1, sortAgg_value_4);
/* 119 */ }
/* 120 */ append((sortAgg_mutableStateArray_0[0].getRow()));
/* 121 */ }
/* 122 */ }
/* 123 */
/* 124 */ private void sortAgg_doConsume_0(InternalRow inputadapter_row_0, long sortAgg_expr_0_0, boolean sortAgg_exprIsNull_0_0, double sortAgg_expr_1_0, boolean sortAgg_exprIsNull_1_0, long sortAgg_expr_2_0, boolean sortAgg_exprIsNull_2_0) throws java.io.IOException {
/* 125 */ // do aggregate
/* 126 */ // common sub-expressions
/* 127 */
/* 128 */ // evaluate aggregate functions and update aggregation buffers
/* 129 */ sortAgg_doAggregate_max_0(sortAgg_expr_0_0, sortAgg_exprIsNull_0_0);
/* 130 */ sortAgg_doAggregate_avg_0(sortAgg_exprIsNull_1_0, sortAgg_expr_1_0, sortAgg_exprIsNull_2_0, sortAgg_expr_2_0);
/* 131 */
/* 132 */ }
/* 133 */
/* 134 */ private void sortAgg_doAggregate_avg_0(boolean sortAgg_exprIsNull_1_0, double sortAgg_expr_1_0, boolean sortAgg_exprIsNull_2_0, long sortAgg_expr_2_0) throws java.io.IOException {
/* 135 */ boolean sortAgg_isNull_13 = true;
/* 136 */ double sortAgg_value_13 = -1.0;
/* 137 */
/* 138 */ if (!sortAgg_bufIsNull_1) {
/* 139 */ if (!sortAgg_exprIsNull_1_0) {
/* 140 */ sortAgg_isNull_13 = false; // resultCode could change nullability.
/* 141 */
/* 142 */ sortAgg_value_13 = sortAgg_bufValue_1 + sortAgg_expr_1_0;
/* 143 */
/* 144 */ }
/* 145 */
/* 146 */ }
/* 147 */ boolean sortAgg_isNull_16 = true;
/* 148 */ long sortAgg_value_16 = -1L;
/* 149 */
/* 150 */ if (!sortAgg_bufIsNull_2) {
/* 151 */ if (!sortAgg_exprIsNull_2_0) {
/* 152 */ sortAgg_isNull_16 = false; // resultCode could change nullability.
/* 153 */
/* 154 */ sortAgg_value_16 = sortAgg_bufValue_2 + sortAgg_expr_2_0;
/* 155 */
/* 156 */ }
/* 157 */
/* 158 */ }
/* 159 */
/* 160 */ sortAgg_bufIsNull_1 = sortAgg_isNull_13;
/* 161 */ sortAgg_bufValue_1 = sortAgg_value_13;
/* 162 */
/* 163 */ sortAgg_bufIsNull_2 = sortAgg_isNull_16;
/* 164 */ sortAgg_bufValue_2 = sortAgg_value_16;
/* 165 */ }
/* 166 */
/* 167 */ }