WholeStageCodegen full code generation in Spark -- starting from aggregate code generation (3)


    Background

    This post is based on Spark 3.3.0.
    We explore the logic of Spark codegen starting from a unit test:

      test("SortAggregate should be included in WholeStageCodegen") {
        val df = spark.range(10).agg(max(col("id")), avg(col("id")))
        withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
          val plan = df.queryExecution.executedPlan
          assert(plan.exists(p =>
            p.isInstanceOf[WholeStageCodegenExec] &&
              p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
          assert(df.collect() === Array(Row(9, 4.5)))
        }
      }
    The whole-stage code generation part of the first stage of the physical plan produced by this query is as follows:
    
    WholeStageCodegen
    
    *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
    +- *(1) Range (0, 10, step=1, splits=2)

    
    

    Analysis

    The first-stage wholeStageCodegen

    The first stage of code generation involves the produce and consume methods of SortAggregateExec and RangeExec; let's analyze them one by one.
    The data flow of the first-stage wholeStageCodegen is as follows:

     WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
      =========================================================================
     
      -> execute()
          |
       doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
          |
       doCodeGen()
          |
          +----------------->   produce()
                                  |
                               doProduce() 
                                  |
                               doProduceWithoutKeys() -------> produce()
                                                                  |
                                                              doProduce()
                                                                  |
                               doConsume()<------------------- consume()
                                  |
                               doConsumeWithoutKeys()
                                   |(consume here is invoked by doProduceWithoutKeys, not by doConsumeWithoutKeys)
       doConsume()  <--------  consume()
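The call pattern in the diagram above can be modeled with a minimal, self-contained sketch. Note that all class and method names below are hypothetical simplifications, not Spark's actual CodegenSupport API: each operator's produce() delegates to its child, and the child's consume() calls back into the parent's doConsume(), splicing the parent's generated code into the child's loop body.

```java
// Minimal model of the produce/consume handshake diagrammed above.
// All names are hypothetical simplifications of Spark's CodegenSupport trait.
interface CodegenSupport {
    String produce(CodegenSupport parent);
    String doConsume(String inputVar);
}

// Stands in for RangeExec: emits a loop and splices the parent's code inside it.
class RangeNode implements CodegenSupport {
    private CodegenSupport parent;

    public String produce(CodegenSupport parent) {
        this.parent = parent;  // keep the parent reference, like this.parent = parent
        return "for (long v = 0; v < 10; v++) { " + consume("v") + " }";
    }

    // consume() hands the produced variable back to the parent's doConsume()
    private String consume(String variable) {
        return parent.doConsume(variable);
    }

    public String doConsume(String inputVar) {
        return "";  // leaf node: nothing to consume
    }
}

// Stands in for SortAggregateExec(partial): produce() just recurses into the child.
class AggNode implements CodegenSupport {
    private final CodegenSupport child;

    AggNode(CodegenSupport child) { this.child = child; }

    public String produce(CodegenSupport parent) {
        return child.produce(this);  // the child will call back via doConsume
    }

    public String doConsume(String inputVar) {
        return "agg_max = Math.max(agg_max, " + inputVar + ");";
    }
}

public class ProduceConsumeModel {
    public static void main(String[] args) {
        String code = new AggNode(new RangeNode()).produce(null);
        System.out.println(code);
        // prints: for (long v = 0; v < 10; v++) { agg_max = Math.max(agg_max, v); }
    }
}
```

The key design point mirrored here is that the leaf operator owns the loop, while each parent contributes only the per-row statement that gets inlined into the loop body.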
    
    
    RangeExec's produce
    final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
        this.parent = parent
        ctx.freshNamePrefix = variablePrefix
        s"""
           |${ctx.registerComment(s"PRODUCE: ${this.simpleString(conf.maxToStringFields)}")}
           |${doProduce(ctx)}
         """.stripMargin
      }
    
    • this.parent = parent and ctx.freshNamePrefix = variablePrefix
      parent is set so that, while running the consume method, the node can reach its parent and call the parent's consume method for code generation.
      freshNamePrefix distinguishes the methods generated for different physical operators, preventing duplicate method names that would break compilation.
    • ctx.registerComment
      This attaches comments to the generated Java code. It is off by default, because spark.sql.codegen.comments defaults to false.
    protected override def doProduce(ctx: CodegenContext): String = {
        val numOutput = metricTerm(ctx, "numOutputRows")
    
        val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
        val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")
    
        val value = ctx.freshName("value")
        val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
        val BigInt = classOf[java.math.BigInteger].getName
    
        // Inline mutable state since not many Range operations in a task
        val taskContext = ctx.addMutableState("TaskContext", "taskContext",
          v => s"$v = TaskContext.get();", forceInline = true)
        val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
          v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true)
    
        // In order to periodically update the metrics without inflicting performance penalty, this
        // operator produces elements in batches. After a batch is complete, the metrics are updated
        // and a new batch is started.
        // In the implementation below, the code in the inner loop is producing all the values
        // within a batch, while the code in the outer loop is setting batch parameters and updating
        // the metrics.
    
        // Once nextIndex == batchEnd, it's time to progress to the next batch.
        val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
    
        // How many values should still be generated by this range operator.
        val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")
    
        // How many values should be generated in the next batch.
        val nextBatchTodo = ctx.freshName("nextBatchTodo")
    
        // The default size of a batch, which must be positive integer
        val batchSize = 1000
    
        val initRangeFuncName = ctx.addNewFunction("initRange",
          s"""
            | private void initRange(int idx) {
            |   $BigInt index = $BigInt.valueOf(idx);
            |   $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
            |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
            |   $BigInt step = $BigInt.valueOf(${step}L);
            |   $BigInt start = $BigInt.valueOf(${start}L);
            |   long partitionEnd;
            |
            |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
            |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
            |     $nextIndex = Long.MAX_VALUE;
            |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
            |     $nextIndex = Long.MIN_VALUE;
            |   } else {
            |     $nextIndex = st.longValue();
            |   }
            |   $batchEnd = $nextIndex;
            |
            |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
            |     .multiply(step).add(start);
            |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
            |     partitionEnd = Long.MAX_VALUE;
            |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
            |     partitionEnd = Long.MIN_VALUE;
            |   } else {
            |     partitionEnd = end.longValue();
            |   }
            |
            |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
            |     $BigInt.valueOf($nextIndex));
            |   $numElementsTodo  = startToEnd.divide(step).longValue();
            |   if ($numElementsTodo < 0) {
            |     $numElementsTodo = 0;
            |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
            |     $numElementsTodo++;
            |   }
            | }
           """.stripMargin)
    
        val localIdx = ctx.freshName("localIdx")
        val localEnd = ctx.freshName("localEnd")
        val stopCheck = if (parent.needStopCheck) {
          s"""
             |if (shouldStop()) {
             |  $nextIndex = $value + ${step}L;
             |  $numOutput.add($localIdx + 1);
             |  $inputMetrics.incRecordsRead($localIdx + 1);
             |  return;
             |}
           """.stripMargin
        } else {
          "// shouldStop check is eliminated"
        }
        val loopCondition = if (limitNotReachedChecks.isEmpty) {
          "true"
        } else {
          limitNotReachedChecks.mkString(" && ")
        }
        s"""
          | // initialize Range
          | if (!$initTerm) {
          |   $initTerm = true;
          |   $initRangeFuncName(partitionIndex);
          | }
          |
          | while ($loopCondition) {
          |   if ($nextIndex == $batchEnd) {
          |     long $nextBatchTodo;
          |     if ($numElementsTodo > ${batchSize}L) {
          |       $nextBatchTodo = ${batchSize}L;
          |       $numElementsTodo -= ${batchSize}L;
          |     } else {
          |       $nextBatchTodo = $numElementsTodo;
          |       $numElementsTodo = 0;
          |       if ($nextBatchTodo == 0) break;
          |     }
          |     $batchEnd += $nextBatchTodo * ${step}L;
          |   }
          |
          |   int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
          |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
          |     long $value = ((long)$localIdx * ${step}L) + $nextIndex;
          |     ${consume(ctx, Seq(ev))}
          |     $stopCheck
          |   }
          |   $nextIndex = $batchEnd;
          |   $numOutput.add($localEnd);
          |   $inputMetrics.incRecordsRead($localEnd);
          |   $taskContext.killTaskIfInterrupted();
          | }
         """.stripMargin
      }
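Before walking through each variable, it is worth isolating how initRange slices the global range across partitions. Below is a standalone sketch (the helper name partitionStart is hypothetical) of just that arithmetic, with the same BigInteger math and Long.MIN/MAX clamping the template emits, applied to this query's Range(0, 10, step=1, splits=2):

```java
import java.math.BigInteger;

public class InitRangeSketch {
    // Computes a partition's starting value the same way the generated initRange does:
    // st = idx * numElements / numSlices * step + start, in BigInteger to avoid
    // long overflow, clamped to Long.MIN_VALUE / Long.MAX_VALUE.
    static long partitionStart(int idx, long numSlices, long numElements, long step, long start) {
        BigInteger st = BigInteger.valueOf(idx)
            .multiply(BigInteger.valueOf(numElements))
            .divide(BigInteger.valueOf(numSlices))
            .multiply(BigInteger.valueOf(step))
            .add(BigInteger.valueOf(start));
        if (st.compareTo(BigInteger.valueOf(Long.MAX_VALUE)) > 0) return Long.MAX_VALUE;
        if (st.compareTo(BigInteger.valueOf(Long.MIN_VALUE)) < 0) return Long.MIN_VALUE;
        return st.longValue();
    }

    public static void main(String[] args) {
        // Two slices over range(0, 10): partition 0 starts at 0, partition 1 at 5.
        System.out.println(partitionStart(0, 2, 10, 1, 0)); // 0
        System.out.println(partitionStart(1, 2, 10, 1, 0)); // 5
    }
}
```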
    
    • val numOutput = metricTerm(ctx, "numOutputRows")
      The numOutput metric records the number of output rows.

    • val initTerm = and val nextIndex =
      initTerm flags whether this operator's range has already been initialized,
      and nextIndex is the logical index used to produce and iterate over RangeExec's data.
      Both are members of the generated class, i.e. global variables.

    • val value = and val ev =
      ev represents the value produced by RangeExec and is eventually passed to consume(ctx, Seq(ev)).
      The value variable is assigned in long $value = ((long)$localIdx * ${step}L) + $nextIndex; so that the parent node can consume it.

    • val taskContext = and val inputMetrics =
      taskContext and inputMetrics are also global variables, and they carry initialization code. That code runs in the generated class's init method, producing the following:

      range_taskContext_0 = TaskContext.get();
      range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
      

      The initialization happens in the init method because each initializer is appended to mutableStateInitCode, an Array-typed variable, and the entries in mutableStateInitCode
      are assembled by WholeStageCodegenExec via ctx.initMutableStates(). The call site looks like:

       
       public void init(int index, scala.collection.Iterator[] inputs) {
             partitionIndex = index;
             this.inputs = inputs;
             ${ctx.initMutableStates()}
             ${ctx.initPartition()}
           }
      
    • val batchEnd = and val numElementsTodo
      These two are also members of the generated class, i.e. global variables.

    • val nextBatchTodo =
      This is a local temporary variable used while iterating to produce data.

    • val initRangeFuncName =
      This is RangeExec's data-production logic; it differs for every physical operator, so we skip the details here.

    • The final while ($loopCondition)
      This part produces different data depending on each partition's index.
      Note the partitionIndex variable in initRangeFuncName(partitionIndex): it lives in BufferedRowIterator, the parent class of the generated class,
      and it is also assigned in the init method, as follows:

      public void init(int index, scala.collection.Iterator[] inputs) {
              partitionIndex = index;
              this.inputs = inputs;
              ...
      }
      
    • consume(ctx, Seq(ev))
      The parent node consumes the data produced by RangeExec; this is covered in what follows.

    • numOutput, inputMetrics and taskContext
      numOutput increments the output-row count.
      inputMetrics increments the records-read counter at the taskMetrics level.
      taskContext.killTaskIfInterrupted() checks whether the current task has been killed, and throws an exception if it has.
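Putting the pieces above together, the batching structure of the generated loop can be sketched as plain, runnable Java. The class and field names below are illustrative stand-ins, not the actual generated identifiers; out.add(value) stands where the real code inlines the parent's consume():

```java
public class RangeLoopSketch {
    // These mirror the mutable class fields the codegen emits.
    long nextIndex = 0;
    long batchEnd = 0;
    long numElementsTodo;
    long numOutputRows = 0;  // stands in for the numOutput metric
    final long step;
    static final long BATCH_SIZE = 1000L;

    RangeLoopSketch(long numElements, long step) {
        this.numElementsTodo = numElements;
        this.step = step;
    }

    java.util.List<Long> run() {
        java.util.List<Long> out = new java.util.ArrayList<>();
        while (true) {
            // At a batch boundary, decide how many values the next batch produces.
            if (nextIndex == batchEnd) {
                long nextBatchTodo;
                if (numElementsTodo > BATCH_SIZE) {
                    nextBatchTodo = BATCH_SIZE;
                    numElementsTodo -= BATCH_SIZE;
                } else {
                    nextBatchTodo = numElementsTodo;
                    numElementsTodo = 0;
                    if (nextBatchTodo == 0) break;  // nothing left to produce
                }
                batchEnd += nextBatchTodo * step;
            }
            // Inner loop: produce every value of the current batch.
            int localEnd = (int) ((batchEnd - nextIndex) / step);
            for (int localIdx = 0; localIdx < localEnd; localIdx++) {
                long value = (long) localIdx * step + nextIndex;
                out.add(value);  // the real generated code inlines consume() here
            }
            nextIndex = batchEnd;
            numOutputRows += localEnd;  // metrics updated once per batch, not per row
        }
        return out;
    }

    public static void main(String[] args) {
        // Range(0, 10, step=1) produces 0..9 in a single batch.
        System.out.println(new RangeLoopSketch(10, 1).run());
    }
}
```

This shows why the batch exists at all: metrics and the interrupt check run once per batch (up to 1000 rows) instead of once per row, keeping the hot inner loop free of bookkeeping.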

  • Original article: https://blog.csdn.net/monkeyboy_tech/article/details/126774608