• Whole-stage code generation (wholeStageCodegen) in Spark: GenerateUnsafeProjection.createCode explained


    Background

    This post explains the GenerateUnsafeProjection.createCode method as it is invoked from RangeExec.

    Analysis

    The corresponding code is:

      def createCode(
          ctx: CodegenContext,
          expressions: Seq[Expression],
          useSubexprElimination: Boolean = false): ExprCode = {
        val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
        val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))
    
        val numVarLenFields = exprSchemas.count {
          case Schema(dt, _) => !UnsafeRow.isFixedLength(dt)
          // TODO: consider large decimal and interval type
        }
    
        val rowWriterClass = classOf[UnsafeRowWriter].getName
        val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
          v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});")
    
        // Evaluate all the subexpression.
        val evalSubexpr = ctx.subexprFunctionsCode
    
        val writeExpressions = writeExpressionsToBuffer(
          ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true)
    //   println(s"writeExpressions: $writeExpressions")
        val code =
          code"""
             |$rowWriter.reset();
             |$evalSubexpr
             |$writeExpressions
           """.stripMargin
        // `rowWriter` is declared as a class field, so we can access it directly in methods.
    //    println(s"code: $code")
        ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow]))
      }
    
    

    Here expressions is Seq(BoundReference(0, long, false)),
    and useSubexprElimination is false.

    • val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
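      This step only generates code. exprEvals holds a single ExprCode whose value is range_value_0.
      Because useSubexprElimination is false, no common subexpression elimination is performed.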
    • val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))
      Builds the schema (data type and nullability) of each expression.
    • val numVarLenFields =
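      Counts the fields that are not fixed-length. The count is used to initialize the UnsafeRowWriter (initial buffer size numVarLenFields * 32). Here the only field is a long, which is fixed-length, so the count is 0.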
    • val rowWriter =
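      Declares and initializes rowWriter. It is added as mutable state, so it becomes a field of the generated class rather than a local variable. The generated code is: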
       private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
       public void init(int index, scala.collection.Iterator[] inputs) {
       ...
       range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
       }
      
    • val evalSubexpr = ctx.subexprFunctionsCode
      This is an empty string here, because subexpression elimination is disabled.
    • val writeExpressions = writeExpressionsToBuffer
      private def writeExpressionsToBuffer(
          ctx: CodegenContext,
          row: String,
          inputs: Seq[ExprCode],
          schemas: Seq[Schema],
          rowWriter: String,
          isTopLevel: Boolean = false): String = {
        val resetWriter = if (isTopLevel) {
          // For top level row writer, it always writes to the beginning of the global buffer holder,
          // which means its fixed-size region always in the same position, so we don't need to call
          // `reset` to set up its fixed-size region every time.
          if (inputs.map(_.isNull).forall(_ == FalseLiteral)) {
            // If all fields are not nullable, which means the null bits never changes, then we don't
            // need to clear it out every time.
            ""
          } else {
            s"$rowWriter.zeroOutNullBytes();"
          }
        } else {
          s"$rowWriter.resetRowWriter();"
        }
      
        val writeFields = inputs.zip(schemas).zipWithIndex.map {
          case ((input, Schema(dataType, nullable)), index) =>
            val dt = UserDefinedType.sqlType(dataType)
      
            val setNull = dt match {
              case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
                // Can't call setNullAt() for DecimalType with precision larger than 18.
                s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});"
              case CalendarIntervalType => s"$rowWriter.write($index, (CalendarInterval) null);"
              case _ => s"$rowWriter.setNullAt($index);"
            }
      
            val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
            if (!nullable) {
              s"""
                 |${input.code}
                 |${writeField.trim}
               """.stripMargin
            } else {
              s"""
                 |${input.code}
                 |if (${input.isNull}) {
                 |  ${setNull.trim}
                 |} else {
                 |  ${writeField.trim}
                 |}
               """.stripMargin
            }
        }
      
        val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != null)) {
          // TODO: support whole stage codegen
          writeFields.mkString("\n")
        } else {
          assert(row != null, "the input row name cannot be null when generating code to write it.")
          ctx.splitExpressions(
            expressions = writeFields,
            funcName = "writeFields",
            arguments = Seq("InternalRow" -> row))
        }
        s"""
           |$resetWriter
           |$writeFieldsCode
         """.stripMargin
      }
      
      • val resetWriter =
        isTopLevel is true and isNull is FalseLiteral for every input, so resetWriter is an empty string.

      • val writeFields =
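        The input is of type LONG, so in val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) the matching branch is
        case _ => s"$writer.write($index, $input);". Because the field is not nullable, no null check is wrapped around it, and the generated code is:

         range_mutableStateArray_0[0].write(0, range_value_0);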
        
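      • val writeFieldsCode = and the final assembly
        In this case the first branch is taken, so the write statements of the individual fields are joined with newlines (writeFields.mkString("\n")). The result is then placed after resetWriter.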

    • val code =
      Assembles the code part of the ExprCode. The generated code is:
      range_mutableStateArray_0[0].reset();
      
      range_mutableStateArray_0[0].write(0, range_value_0);
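The string assembly walked through above can be imitated in plain Java without any Spark dependency. The following is only a sketch under simplifying assumptions: the class WriteCodeSketch, the Field holder and both method names are hypothetical, only the default write branch and the setNullAt branch are modelled, and the splitExpressions path is left out. It is meant to show why a single non-nullable long yields exactly one write statement with no null handling.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Plain-Java imitation of the template assembly (hypothetical names, not Spark code). */
final class WriteCodeSketch {

    /** Stand-in for one input: its Java value expression, its isNull expression, its nullability. */
    static final class Field {
        final String value;
        final String isNull;
        final boolean nullable;

        Field(String value, String isNull, boolean nullable) {
            this.value = value;
            this.isNull = isNull;
            this.nullable = nullable;
        }
    }

    /** Mirrors the top-level case: optional null-bit reset followed by the per-field writes. */
    static String writeExpressions(String rowWriter, List<Field> fields) {
        // The null bits only need clearing if at least one input may be null.
        boolean noNulls = true;
        for (Field f : fields) {
            noNulls &= f.isNull.equals("false");
        }
        String resetWriter = noNulls ? "" : rowWriter + ".zeroOutNullBytes();";

        List<String> writeFields = new ArrayList<>();
        for (int i = 0; i < fields.size(); i++) {
            Field f = fields.get(i);
            String write = rowWriter + ".write(" + i + ", " + f.value + ");";
            if (!f.nullable) {
                // Non-nullable field: emit the write statement directly.
                writeFields.add(write);
            } else {
                // Nullable field: wrap the write in a null check.
                writeFields.add("if (" + f.isNull + ") {\n"
                    + "  " + rowWriter + ".setNullAt(" + i + ");\n"
                    + "} else {\n"
                    + "  " + write + "\n"
                    + "}");
            }
        }
        // The per-field snippets are joined with newlines.
        String body = String.join("\n", writeFields);
        return resetWriter.isEmpty() ? body : resetWriter + "\n" + body;
    }

    /** Mirrors the final template: reset the writer, then write all fields. */
    static String createCode(String rowWriter, List<Field> fields) {
        return rowWriter + ".reset();\n" + writeExpressions(rowWriter, fields);
    }

    public static void main(String[] args) {
        String writer = "range_mutableStateArray_0[0]";
        List<Field> fields = Arrays.asList(new Field("range_value_0", "false", false));
        System.out.println(createCode(writer, fields));
    }
}
```

Run with the Range input, the sketch prints the reset() call followed by the single write(0, range_value_0) statement. Marking the field nullable would add the zeroOutNullBytes() call and the if/else null check.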
      

    The complete ExprCode is then:

     ExprCode(range_mutableStateArray_0[0].reset();
    range_mutableStateArray_0[0].write(0, range_value_0);,false,(range_mutableStateArray_0[0].getRow()))
    
  • Original article: https://blog.csdn.net/monkeyboy_tech/article/details/126859545