• SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(8)


    背景

    本文基于 SPARK 3.3.0
    从一个unit test来探究SPARK Codegen的逻辑,

      test("SortAggregate should be included in WholeStageCodegen") {
        val df = spark.range(10).agg(max(col("id")), avg(col("id")))
        withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
          val plan = df.queryExecution.executedPlan
          assert(plan.exists(p =>
            p.isInstanceOf[WholeStageCodegenExec] &&
              p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
          assert(df.collect() === Array(Row(9, 4.5)))
        }
      }
    

    该sql形成的执行计划第二部分的全代码生成部分如下:

    WholeStageCodegen
    *(2) SortAggregate(key=[], functions=[max(id#0L), avg(id#0L)], output=[max(id)#5L, avg(id)#6])
       InputAdapter
    +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#13]
    

    分析

    第二阶段wholeStageCodegen

    第二阶段的代码生成涉及到SortAggregateExec和ShuffleExchangeExec以及InputAdapter的produce和consume方法,这里一一来分析:
    第二阶段wholeStageCodegen数据流如下:

     WholeStageCodegenExec      SortAggregateExec(Final)      InputAdapter       ShuffleExchangeExec        
      ====================================================================================
     
      -> execute()
          |
       doExecute() --------->   inputRDDs() -----------------> inputRDDs() -------> execute()
          |                                                                            |
       doCodeGen()                                                                  doExecute()     
          |                                                                            |
          +----------------->   produce()                                           ShuffledRowRDD
                                  |
                               doProduce() 
                                  |
                               doProduceWithoutKeys() -------> produce()
                                                                  |
                                                              doProduce()
                                                                  |
                               doConsume() <------------------- consume()
                                  |
                               doConsumeWithoutKeys()
                                  |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
       doConsume()  <--------  consume()
    
    
    SortAggregateExec(Final) 的doProduce

    这里只列出和SortAggregateExec(Partial)的不同的部分:

        val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
          // evaluate aggregate results
          ctx.currentVars = flatBufVars
          val aggResults = bindReferences(
            functions.map(_.evaluateExpression),
            aggregateBufferAttributes).map(_.genCode(ctx))
          val evaluateAggResults = evaluateVariables(aggResults)
          // evaluate result expressions
          ctx.currentVars = aggResults
          val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
          (resultVars,
            s"""
               |$evaluateAggResults
               |${evaluateVariables(resultVars)}
             """.stripMargin)
    
    • 因为我们这里是Final部分,所以我们的数据流和Partial是不同的
    • ctx.currentVars = flatBufVars
      赋值currentVars为当前buffer变量,便于下面进行数据绑定,该buffer变量是全局变量
    • val aggResults = bindReferences
      1. functions.map(_.evaluateExpression) 这是对最终输出结果的计算,对于SUM来说是Divide(sum.cast(resultType), count.cast(resultType), failOnError = false) ,生成的代码如下:
         boolean sortAgg_isNull_6 = sortAgg_bufIsNull_2;
         double sortAgg_value_6 = -1.0;
         if (!sortAgg_bufIsNull_2) {
           sortAgg_value_6 = (double) sortAgg_bufValue_2;
         }
         boolean sortAgg_isNull_4 = false;
         double sortAgg_value_4 = -1.0;
         if (sortAgg_isNull_6 || sortAgg_value_6 == 0) {
           sortAgg_isNull_4 = true;
         } else {
           if (sortAgg_bufIsNull_1) {
             sortAgg_isNull_4 = true;
           } else {
             sortAgg_value_4 = (double)(sortAgg_bufValue_1 / sortAgg_value_6);
           }
         }
      
      1. aggregateBufferAttributes 聚合函数的buffer属性值 sum :: count :: Nil
        这样在绑定数据的变量数据的时候和currentVars是一一对应的
    • val evaluateAggResults = evaluateVariables(aggResults)
      对聚合的结果进行最终的计算
    • ctx.currentVars = aggResults
      把最终结果的变量赋值给currentVars,便于后面的数据绑定
    • val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
      这一步是把聚合结果的变量绑定到聚合表达式中,
      其中resultExpressionsList( avg(id#0L)#3 AS avg(id)#6) (这里我们只考虑AVG)
      aggregateAttributesresultExpressionAttributeReference的一种表达,便于在BoundReference的时候进行映射绑定
      对应的ExprCode为ExprCode(,sortAgg_isNull_4,sortAgg_value_4))
    InputAdaptor的 doProduce

    InputAdaptor的主要作用是承上启下,用来适配不支持Codegen的物理计划,sql如下:

      override def doProduce(ctx: CodegenContext): String = {
     // Inline mutable state since an InputRDDCodegen is used once in a task for WholeStageCodegen
     val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
       forceInline = true)
     val row = ctx.freshName("row")
    
     val outputVars = if (createUnsafeProjection) {
       // creating the vars will make the parent consume add an unsafe projection.
       ctx.INPUT_ROW = row
       ctx.currentVars = null
       output.zipWithIndex.map { case (a, i) =>
         BoundReference(i, a.dataType, a.nullable).genCode(ctx)
       }
     } else {
       null
     }
    
     val updateNumOutputRowsMetrics = if (metrics.contains("numOutputRows")) {
       val numOutputRows = metricTerm(ctx, "numOutputRows")
       s"$numOutputRows.add(1);"
     } else {
       ""
     }
     s"""
        | while ($limitNotReachedCond $input.hasNext()) {
        |   InternalRow $row = (InternalRow) $input.next();
        |   ${updateNumOutputRowsMetrics}
        |   ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim}
        |   ${shouldStopCheckCode}
        | }
      """.stripMargin
    }
    
    • val input = ctx.addMutableState(“scala.collection.Iterator”, “input”, v => s"$v = inputs[0];"
      定义一个input变量用来接受sortaggregate(partial)的输出的InteralRow(unsafeRow),对应的初始化方法会在init方法中调用
    • val row = ctx.freshName(“row”)
      定义一个临时变量用来接受input中的unsafe类型的InteralRow,便于进行迭代操作
    • val outputVars = if (createUnsafeProjection)
      对于InputAdaptor来说createUnsafeProjectionfalse, 所以这块返回的是null
    • val updateNumOutputRowsMetrics =
      因为metrics不满足条件,所以这里也是返回空字符串
    • 代码组装
          s"""
         | while ($limitNotReachedCond $input.hasNext()) {
         |   InternalRow $row = (InternalRow) $input.next();
         |   ${updateNumOutputRowsMetrics}
         |   ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim}
         |   ${shouldStopCheckCode}
         | }
       """.stripMargin
      
      对输入的每一行数据进行迭代操作, 之后再调用consume方法,
      注意: 这里的consume传入的是row,是InteralRow类型,而不是在RangeExec中的Long类型的变量
    InputAdaptor的 consume

    我们这里只说明和之前不一样的部分,对应的sql如下:

      final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String =
    

    注意这里的参数 outputVarsnull
    rowInteralRow类型的变量

    • val inputVarsCandidate =
     val inputVarsCandidate =
       if (outputVars != null) {
         assert(outputVars.length == output.length)
         // outputVars will be used to generate the code for UnsafeRow, so we should copy them
         outputVars.map(_.copy())
       } else {
         assert(row != null, "outputVars and row cannot both be null.")
         ctx.currentVars = null
         ctx.INPUT_ROW = row
         output.zipWithIndex.map { case (attr, i) =>
           BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
         }
      }
    

    这里的数据流向了 else :

    • ctx.INPUT_ROW = row
      设置当前的INPUT_ROWrow
      BoundReferencedoGenCode方法也是走向了另一个分支:
       assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
       val javaType = JavaCode.javaType(dataType)
       val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
       if (nullable) {
         ev.copy(code =
           code"""
              |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
              |$javaType ${ev.value} = ${ev.isNull} ?
              |  ${CodeGenerator.defaultValue(dataType)} : ($value);
            """.stripMargin)
       } else {
         ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
       }
    
    • 分析
      • val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType,ordinal.toString)
        根据数据类型的不同,调用UnsafeRow的不同方法

      • if (nullable)
        因为AttributeReference("sum", sumDataType)()AttributeReference("count", LongType)()表达式 nullableTRUE,所以生成的代码为:

        boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
        long inputadapter_value_0 = inputadapter_isNull_0 ?
        -1L : (inputadapter_row_0.getLong(0));
        boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1);
        double inputadapter_value_1 = inputadapter_isNull_1 ?
        -1.0 : (inputadapter_row_0.getDouble(1));
        boolean inputadapter_isNull_2 = inputadapter_row_0.isNullAt(2);
        long inputadapter_value_2 = inputadapter_isNull_2 ?
        -1L : (inputadapter_row_0.getLong(2));
        
    • constructDoConsumeFunction方法中inputVarsInFunc
      这里会多一个名为inputadapter_row_0的InternalRow类型的实参
  • 相关阅读:
    LaTeX:在标题section中添加脚注footnote
    数据库基本结论
    LeetCode 面试题 10.03. 搜索旋转数组
    数字逻辑设计(5)
    # 数据库开发-MySQL基础DDL-DML总结
    2022牛客多校联赛加赛 题解
    【艾特淘】8月22日之后,抖音精选联盟准入标准变了
    Windows11安装Vim编辑器配置指南
    uniapp中vue3使用uni.createSelectorQuery().in(this)报错
    uni.app小程序的ajax封装详细讲解
  • 原文地址:https://blog.csdn.net/monkeyboy_tech/article/details/126941773