• WholeStageCodegen in Spark: a walkthrough using aggregate code generation as the example (5)


    Background

    This article is based on Spark 3.3.0.
    We explore the Spark codegen logic starting from a unit test:

      test("SortAggregate should be included in WholeStageCodegen") {
        val df = spark.range(10).agg(max(col("id")), avg(col("id")))
        withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
          val plan = df.queryExecution.executedPlan
          assert(plan.exists(p =>
            p.isInstanceOf[WholeStageCodegenExec] &&
              p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
          assert(df.collect() === Array(Row(9, 4.5)))
        }
      }
    The whole-stage codegen part of the first stage of the execution plan produced by this query is as follows:
    

    WholeStageCodegen
    +- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
       +- *(1) Range (0, 10, step=1, splits=2)

    
    

    Analysis

    Stage 1 wholeStageCodegen

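    Code generation for the first stage involves the produce and consume methods of SortAggregateExec and RangeExec; we analyze them one by one.
    The data flow of the first-stage wholeStageCodegen is as follows: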

     WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
      =========================================================================
     
      -> execute()
          |
       doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
          |
       doCodeGen()
          |
          +----------------->   produce()
                                  |
                               doProduce() 
                                  |
                               doProduceWithoutKeys() -------> produce()
                                                                  |
                                                              doProduce()
                                                                  |
                               doConsume()<------------------- consume()
                                  |
                               doConsumeWithoutKeys()
                                  |consume is called by doProduceWithoutKeys, not by doConsumeWithoutKeys
       doConsume()  <--------  consume()
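The call order in the diagram can be illustrated with a toy model. This is only a sketch: the `Node` interface and the `Range`, `PartialAgg`, and `Stage` classes are hypothetical stand-ins for Spark's operators, and the emitted "code" is placeholder text, not what Spark generates. It shows the leaf's produce asking its parent how to consume each row, and the aggregate calling its own parent's consume from the produce side, after the child's loop.

```java
// Toy model of the produce/consume protocol; all names here are hypothetical, not Spark's API.
final class ProduceConsumeSketch {
    interface Node {
        String produce(Node parent);   // emit code that drives row generation
        String doConsume(String row);  // emit code that handles one row from the child
    }

    /** Leaf operator: emits the loop and asks the parent how to consume each row. */
    static final class Range implements Node {
        public String produce(Node parent) {
            return "for (long v = 0; v < 10; v++) {\n"
                 + "  " + parent.doConsume("v") + "\n"
                 + "}\n";
        }
        public String doConsume(String row) {
            throw new UnsupportedOperationException("a leaf has no child to consume from");
        }
    }

    /** Aggregate without keys: the parent's consume is emitted from produce, after the child's loop. */
    static final class PartialAgg implements Node {
        private final Node child;
        PartialAgg(Node child) { this.child = child; }
        public String produce(Node parent) {
            return "init buffers;\n"
                 + child.produce(this)
                 + parent.doConsume("buffers") + "\n";
        }
        public String doConsume(String row) {
            return "update buffers with " + row + ";";
        }
    }

    /** Root of the stage: delegates produce to the child and appends whatever it is handed. */
    static final class Stage implements Node {
        private final Node child;
        Stage(Node child) { this.child = child; }
        public String produce(Node parent) { return child.produce(this); }
        public String doConsume(String row) { return "append(" + row + ");"; }
    }

    static String generate() {
        return new Stage(new PartialAgg(new Range())).produce(null);
    }
}
```

Calling `ProduceConsumeSketch.generate()` yields the buffer initialization, then the loop whose body is the aggregate's per-row update, then a single append of the buffers after the loop.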
    
    
    The doConsume method of SortAggregateExec (Partial)
    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
        if (groupingExpressions.isEmpty) {
          doConsumeWithoutKeys(ctx, input)
        } else {
          doConsumeWithKeys(ctx, input)
        }
      }
    

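    Note that although the ExprCode variable row is passed in, this method does not use it; in most cases that variable serves to hand an InternalRow to the outside.
    input, in turn, is sortAgg_expr_0_0, which is assigned from range_value_0.
    The doConsumeWithoutKeys method looks like this: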

      private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
        // only have DeclarativeAggregate
        val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
        val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
        // To individually generate code for each aggregate function, an element in `updateExprs` holds
        // all the expressions for the buffer of an aggregation function.
        val updateExprs = aggregateExpressions.map { e =>
          e.mode match {
            case Partial | Complete =>
              e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
            case PartialMerge | Final =>
              e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
          }
        }
        ctx.currentVars = bufVars.flatten ++ input
        println(s"updateExprs: $updateExprs")
        val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
          bindReferences(updateExprsForOneFunc, inputAttrs)
        }
        val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
        val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
        val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
          ctx.withSubExprEliminationExprs(subExprs.states) {
            boundUpdateExprsForOneFunc.map(_.genCode(ctx))
          }
        }
        val aggNames = functions.map(_.prettyName)
        val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) =>
          val bufVarsForOneFunc = bufVars(i)
          // All the update code for aggregation buffers should be placed in the end
          // of each aggregation function code.
          println(s"bufVarsForOneFunc: $bufVarsForOneFunc")
          val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) =>
            s"""
               |${bufVar.isNull} = ${ev.isNull};
               |${bufVar.value} = ${ev.value};
             """.stripMargin
          }
          code"""
                |${ctx.registerComment(s"do aggregate for ${aggNames(i)}")}
                |${ctx.registerComment("evaluate aggregate function")}
                |${evaluateVariables(bufferEvalsForOneFunc)}
                |${ctx.registerComment("update aggregation buffers")}
                |${updates.mkString("\n").trim}
           """.stripMargin
        }
    
        val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
          ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
        s"""
           |// do aggregate
           |// common sub-expressions
           |$effectiveCodes
           |// evaluate aggregate functions and update aggregation buffers
           |$codeToEvalAggFuncs
         """.stripMargin
      }
    
    
    • val functions = and val inputAttrs =
      In val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes, for the AVG aggregate function the aggregation buffer attributes (aggBufferAttributes) are AttributeReference("sum", sumDataType)() and AttributeReference("count", LongType)().
      For the current plan, the inputAttributes of SortAggregateExec is AttributeReference("id", LongType, nullable = false)().

    • val updateExprs = aggregateExpressions.
      For the current physical plan the mode is Partial, so this value is updateExpressions, that is, the partial update. For AVG it is:

        /* sum = */
        Add(
          sum,
          coalesce(child.cast(sumDataType), Literal.default(sumDataType)),
          failOnError = useAnsiAdd),
        /* count = */ If(child.isNull, count, count + 1L)
    • ctx.currentVars = bufVars.flatten ++ input
      bufVars is assigned on the produce side of SortAggregateExec; it holds the ExprCode for the initial values of "SUM" and "COUNT".
      input here is the ExprCode variable named sortAgg_expr_0_0.

    • val boundUpdateExprs =
      Binds the current input variables into updateExprs (inputAttrs and currentVars clearly correspond one to one).

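    • val subExprs = and val effectiveCodes =
      Performs common subexpression elimination and evaluates the shared subexpressions up front, before the expressions that use them.
      For the current plan, effectiveCodes is an empty string.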

    • val bufferEvals =
      Generates the ExprCode that performs the update; concretely, these are the codegen results for the Add and If expressions, respectively:

      List(ExprCode(boolean sortAgg_isNull_7 = true;
           double sortAgg_value_7 = -1.0;
           if (!sortAgg_bufIsNull_1) {
             sortAgg_sortAgg_isNull_9_0 = true;
             double sortAgg_value_9 = -1.0;
             do {
               boolean sortAgg_isNull_10 = false;
               double sortAgg_value_10 = -1.0;
               if (!false) {
                 sortAgg_value_10 = (double) sortAgg_expr_0_0;
               }
               if (!sortAgg_isNull_10) {
                 sortAgg_sortAgg_isNull_9_0 = false;
                 sortAgg_value_9 = sortAgg_value_10;
                 continue;
               }
               if (!false) {
                 sortAgg_sortAgg_isNull_9_0 = false;
                 sortAgg_value_9 = 0.0D;
                 continue;
               }
             } while (false);
             sortAgg_isNull_7 = false; // resultCode could change nullability.
             sortAgg_value_7 = sortAgg_bufValue_1 + sortAgg_value_9;
           },sortAgg_isNull_7,sortAgg_value_7),

           ExprCode(boolean sortAgg_isNull_13 = false;
           long sortAgg_value_13 = -1L;
           if (!false && false) {
             sortAgg_isNull_13 = sortAgg_bufIsNull_2;
             sortAgg_value_13 = sortAgg_bufValue_2;
           } else {
             boolean sortAgg_isNull_17 = true;
             long sortAgg_value_17 = -1L;
             if (!sortAgg_bufIsNull_2) {
               sortAgg_isNull_17 = false; // resultCode could change nullability.
               sortAgg_value_17 = sortAgg_bufValue_2 + 1L;
             }
             sortAgg_isNull_13 = sortAgg_isNull_17;
             sortAgg_value_13 = sortAgg_value_17;
           },sortAgg_isNull_13,sortAgg_value_13))
      
      
    • val aggNames = functions.map(_.prettyName)
      This defines the names used for the aggregate-function methods, which end up as methods with names like sortAgg_doAggregate_avg_0.

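    • val aggCodeBlocks =
      These are the code blocks for the individual aggregate functions. After the aggregation is evaluated, each block assigns the result to the global buffer variables; the corresponding generated code is: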

        sortAgg_bufIsNull_1 = sortAgg_isNull_7;
        sortAgg_bufValue_1 = sortAgg_value_7;
      
        sortAgg_bufIsNull_2 = sortAgg_isNull_13;
        sortAgg_bufValue_2 = sortAgg_value_13;
      

      Here sortAgg_bufValue_1 represents SUM and sortAgg_bufValue_2 represents COUNT.

    • val codeToEvalAggFuncs = generateEvalCodeForAggFuncs
      Generates the code for each aggregate function, as follows:

           sortAgg_doAggregate_max_0(sortAgg_expr_0_0);
           sortAgg_doAggregate_avg_0(sortAgg_expr_0_0);
      
    • $effectiveCodes
      Assembles the final code.
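Putting the steps above together, the per-row behavior of this stage can be approximated by hand. This is a sketch under stated assumptions, not Spark's actual generated code. The buffer and method names follow the snippets above, but sortAgg_bufIsNull_0 / sortAgg_bufValue_0 for MAX and the body of the max method are assumptions (the source only shows the SUM and COUNT buffers), and the coalesce and If branches are collapsed because id is non-nullable. The class name `SortAggStageSketch` and the `run` helper are hypothetical.

```java
// Hand-written approximation of the stage's per-row logic; not Spark's generated code.
final class SortAggStageSketch {
    // Buffer for MAX: assumed name, assumed to start as null.
    boolean sortAgg_bufIsNull_0 = true;
    long sortAgg_bufValue_0 = -1L;
    // Buffer for SUM (part of avg), starting at 0.0.
    boolean sortAgg_bufIsNull_1 = false;
    double sortAgg_bufValue_1 = 0.0;
    // Buffer for COUNT (part of avg), starting at 0.
    boolean sortAgg_bufIsNull_2 = false;
    long sortAgg_bufValue_2 = 0L;

    // Assumed body: keep the larger of the buffer and the input, treating a null buffer as "no value yet".
    void sortAgg_doAggregate_max_0(long sortAgg_expr_0_0) {
        if (sortAgg_bufIsNull_0 || sortAgg_expr_0_0 > sortAgg_bufValue_0) {
            sortAgg_bufIsNull_0 = false;
            sortAgg_bufValue_0 = sortAgg_expr_0_0;
        }
    }

    void sortAgg_doAggregate_avg_0(long sortAgg_expr_0_0) {
        // Evaluate the aggregate function: sum + cast(id as double), count + 1.
        boolean sortAgg_isNull_7 = true;
        double sortAgg_value_7 = -1.0;
        if (!sortAgg_bufIsNull_1) {
            sortAgg_isNull_7 = false;
            sortAgg_value_7 = sortAgg_bufValue_1 + (double) sortAgg_expr_0_0;
        }
        boolean sortAgg_isNull_13 = true;
        long sortAgg_value_13 = -1L;
        if (!sortAgg_bufIsNull_2) {
            sortAgg_isNull_13 = false;
            sortAgg_value_13 = sortAgg_bufValue_2 + 1L;
        }
        // Update the aggregation buffers at the end of the function's block.
        sortAgg_bufIsNull_1 = sortAgg_isNull_7;
        sortAgg_bufValue_1 = sortAgg_value_7;
        sortAgg_bufIsNull_2 = sortAgg_isNull_13;
        sortAgg_bufValue_2 = sortAgg_value_13;
    }

    // Hypothetical driver: feeds ids in [start, end) through both functions, as one partition would.
    static SortAggStageSketch run(long start, long end) {
        SortAggStageSketch s = new SortAggStageSketch();
        for (long id = start; id < end; id++) {
            s.sortAgg_doAggregate_max_0(id);
            s.sortAgg_doAggregate_avg_0(id);
        }
        return s;
    }
}
```

The sketch runs everything in a single partition, whereas the plan above uses splits=2 and leaves the final merge and the sum/count division to a later stage. The partial stage only outputs the three buffer values, matching output=[max, sum, count] in the plan.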

  • Original article: https://blog.csdn.net/monkeyboy_tech/article/details/126879155