对 RangeExec 中出现的 GenerateUnsafeProjection.createCode 方法进行说明
对应的代码为:
/**
 * Generates Java code that evaluates `expressions` and writes their results into an
 * UnsafeRow through an UnsafeRowWriter that is registered as mutable state (a class field)
 * of the generated class.
 *
 * @param ctx the codegen context; the row writer is registered on it via addMutableState
 * @param expressions the expressions whose values become the fields of the output row,
 *                    in order
 * @param useSubexprElimination forwarded to ctx.generateExpressions; false by default
 * @return an ExprCode whose code resets the writer, emits ctx.subexprFunctionsCode and then
 *         writes every field; its isNull is FalseLiteral and its value is the expression
 *         `<rowWriter>.getRow()` typed as UnsafeRow
 */
def createCode(
ctx: CodegenContext,
expressions: Seq[Expression],
useSubexprElimination: Boolean = false): ExprCode = {
// Evaluation code (one ExprCode per expression).
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
// Data type and nullability of every output field, aligned with exprEvals by position.
val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))
// Number of fields whose type is not fixed-length in an UnsafeRow; used to size the writer.
val numVarLenFields = exprSchemas.count {
case Schema(dt, _) => !UnsafeRow.isFixedLength(dt)
// TODO: consider large decimal and interval type
}
val rowWriterClass = classOf[UnsafeRowWriter].getName
// Register the writer as mutable state. Its constructor receives the field count and
// 32 per variable-length field (presumably an initial buffer size in bytes; confirm
// against UnsafeRowWriter).
val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});")
// Evaluate all the subexpressions.
val evalSubexpr = ctx.subexprFunctionsCode
// Code that writes every evaluated field into the top-level row writer.
val writeExpressions = writeExpressionsToBuffer(
ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true)
// println(s"writeExpressions: $writeExpressions")
// Generated body: reset the writer, evaluate subexpressions, then write the fields.
val code =
code"""
|$rowWriter.reset();
|$evalSubexpr
|$writeExpressions
""".stripMargin
// `rowWriter` is declared as a class field, so we can access it directly in methods.
// println(s"code: $code")
ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow]))
}
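其中 expressions 的值为 Seq(BoundReference(0, long, false))，
useSubexprElimination 为 false。
下面逐个说明 createCode 中各变量的取值：
exprEvals：该表达式对应的值变量为 range_value_0。由于 useSubexprElimination 是 false，所以不会进行公共子表达式消除。
exprSchemas：Seq(Schema(LongType, false))。LONG 是定长类型，因此 numVarLenFields 为 0。
rowWriter：通过 ctx.addMutableState 注册的 UnsafeRowWriter，在生成的代码中保存在可变状态数组里：
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
其初始化代码位于 init 方法中，下标 0 的元素就是本次使用的 rowWriter：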
// Excerpt of the generated init() method (other initialization is elided as "...").
// Slot 0 of the mutable-state array receives the row writer created by createCode:
// the arguments are expressions.length = 1 and numVarLenFields * 32 = 0, because the
// single LONG field is fixed-length.
public void init(int index, scala.collection.Iterator[] inputs) {
...
range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
}
/**
 * Generates Java code that writes already-evaluated inputs into the fields of `rowWriter`.
 *
 * @param ctx the codegen context, used by writeElement and ctx.splitExpressions
 * @param row name of the input row variable; passed as the `InternalRow` argument of the
 *            functions produced by ctx.splitExpressions. It may be null only on the top-level
 *            inline path; otherwise an assertion fails.
 * @param inputs evaluation code for each field; `input.code` is emitted right before the
 *               corresponding field is written
 * @param schemas data type and nullability of each field, aligned with `inputs` by position
 * @param rowWriter name of the UnsafeRowWriter variable that receives the values
 * @param isTopLevel true for the outermost row writer; false makes the code start with
 *                   `resetRowWriter()`
 * @return the reset statement (possibly empty) followed by the per-field write code
 */
private def writeExpressionsToBuffer(
      ctx: CodegenContext,
      row: String,
      inputs: Seq[ExprCode],
      schemas: Seq[Schema],
      rowWriter: String,
      isTopLevel: Boolean = false): String = {
  val resetWriter = if (isTopLevel) {
    // For top level row writer, it always writes to the beginning of the global buffer holder,
    // which means its fixed-size region always in the same position, so we don't need to call
    // `reset` to set up its fixed-size region every time.
    if (inputs.map(_.isNull).forall(_ == FalseLiteral)) {
      // If all fields are not nullable, which means the null bits never changes, then we don't
      // need to clear it out every time.
      ""
    } else {
      // At least one field may be null, so the null bits must be cleared before writing.
      s"$rowWriter.zeroOutNullBytes();"
    }
  } else {
    // Not a top-level writer: the shortcut described above does not apply, so reset it.
    s"$rowWriter.resetRowWriter();"
  }
  // One code snippet per field, built from the input, its schema and its ordinal.
  val writeFields = inputs.zip(schemas).zipWithIndex.map {
    case ((input, Schema(dataType, nullable)), index) =>
      // Presumably unwraps a user-defined type to its underlying SQL type; non-UDTs pass through.
      val dt = UserDefinedType.sqlType(dataType)
      // Statement emitted when the value is null.
      val setNull = dt match {
        case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
          // Can't call setNullAt() for DecimalType with precision larger than 18.
          s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});"
        case CalendarIntervalType => s"$rowWriter.write($index, (CalendarInterval) null);"
        case _ => s"$rowWriter.setNullAt($index);"
      }
      // Statement that writes a non-null value (writeElement is defined elsewhere).
      val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
      if (!nullable) {
        // Non-nullable field: no null check is generated.
        s"""
           |${input.code}
           |${writeField.trim}
         """.stripMargin
      } else {
        // Nullable field: branch on the generated isNull flag.
        s"""
           |${input.code}
           |if (${input.isNull}) {
           |  ${setNull.trim}
           |} else {
           |  ${writeField.trim}
           |}
         """.stripMargin
      }
  }
  // Top-level writes with no row name, or with ctx.currentVars set, are concatenated inline.
  // Otherwise the snippets go to ctx.splitExpressions under the function name "writeFields",
  // with the input row passed as an `InternalRow` argument.
  val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != null)) {
    // TODO: support whole stage codegen
    writeFields.mkString("\n")
  } else {
    assert(row != null, "the input row name cannot be null when generating code to write it.")
    ctx.splitExpressions(
      expressions = writeFields,
      funcName = "writeFields",
      arguments = Seq("InternalRow" -> row))
  }
  // Reset statement (possibly empty) followed by the field writes.
  s"""
     |$resetWriter
     |$writeFieldsCode
   """.stripMargin
}
val resetWriter =
因为 isTopLevel 为 true，并且所有 inputs 的 isNull 都是 FalseLiteral（BoundReference(0, long, false) 不可为 null），所以 resetWriter 的值为空字符串。
val writeFields =
该字段不可为 null（nullable 为 false），所以走 `!nullable` 分支，只生成 input.code 和 writeField，不会生成 null 判断。
因为字段类型是 LONG，val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) 中对应的分支为:
case _ => s"$writer.write($index, $input);"，所以生成的代码为:
range_mutableStateArray_0[0].write(0, range_value_0);
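val writeFieldsCode = 以及后面的代码组装
当 isTopLevel 为 true，并且 row 为 null 或 ctx.currentVars 不为 null 时，各字段的写入代码直接用换行符拼接；否则交给 ctx.splitExpressions 处理。
writeExpressionsToBuffer 返回 resetWriter（此处为空）加上 writeFieldsCode。createCode 再在前面加上 $rowWriter.reset(); 和 evalSubexpr（此处为空），各部分之间用换行符分隔，得到:
range_mutableStateArray_0[0].reset();
range_mutableStateArray_0[0].write(0, range_value_0);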
最后 ExprCode 的完整内容如下（三个参数依次为 code、isNull、value）:
ExprCode(range_mutableStateArray_0[0].reset();
range_mutableStateArray_0[0].write(0, range_value_0);,false,(range_mutableStateArray_0[0].getRow()))