WholeStageCodegen in Spark: Whole-Stage Code Generation, Using Aggregate Codegen as an Example (9)


    Background

    This article is based on Spark 3.3.0.
    We explore the logic of Spark Codegen starting from a unit test:

      test("SortAggregate should be included in WholeStageCodegen") {
        val df = spark.range(10).agg(max(col("id")), avg(col("id")))
        withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
          val plan = df.queryExecution.executedPlan
          assert(plan.exists(p =>
            p.isInstanceOf[WholeStageCodegenExec] &&
              p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
          assert(df.collect() === Array(Row(9, 4.5)))
        }
      }
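
    To inspect the generated code of this query yourself, Spark ships a debug helper; a minimal sketch (it assumes a running spark session, just like the test above):

      // Print every WholeStageCodegen subtree of the plan together with the
      // Java source Spark generates for it.
      import org.apache.spark.sql.execution.debug._
      import org.apache.spark.sql.functions.{avg, col, max}

      val df = spark.range(10).agg(max(col("id")), avg(col("id")))
      df.debugCodegen() // prints "Found N WholeStageCodegen subtrees" plus each source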
    

    The whole-stage-codegen part of the second stage of the execution plan produced by this query is as follows:

    WholeStageCodegen
    *(2) SortAggregate(key=[], functions=[max(id#0L), avg(id#0L)], output=[max(id)#5L, avg(id)#6])
       InputAdapter
    +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#13]
    

    Analysis

    Stage-2 wholeStageCodegen

    The code generation of the second stage involves the produce and consume methods of SortAggregateExec, ShuffleExchangeExec, and InputAdapter; we analyze each of them in turn.
    The data flow of the stage-2 wholeStageCodegen is as follows:

     WholeStageCodegenExec      SortAggregateExec(Final)      InputAdapter       ShuffleExchangeExec        
      ====================================================================================
     
      -> execute()
          |
       doExecute() --------->   inputRDDs() -----------------> inputRDDs() -------> execute()
          |                                                                            |
       doCodeGen()                                                                  doExecute()     
          |                                                                            |
          +----------------->   produce()                                           ShuffledRowRDD
                                  |
                               doProduce() 
                                  |
                               doProduceWithoutKeys() -------> produce()
                                                                  |
                                                              doProduce()
                                                                  |
                               doConsume() <------------------- consume()
                                  |
                               doConsumeWithoutKeys()
                                   | (note: the consume() below is invoked from doProduceWithoutKeys, not from doConsumeWithoutKeys)
       doConsume()  <--------  consume()
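
    The handshake in this diagram can be modeled compactly. The sketch below is a toy model of the produce/consume recursion (MiniCodegen, MiniInputAdapter, and MiniAgg are hypothetical names, not Spark classes): produce() recurses down to the leaf, which generates the row loop and calls consume() back up so each parent contributes its per-row code.

      trait MiniCodegen {
        var parent: MiniCodegen = null
        def produce(): String = doProduce()
        def doProduce(): String
        def consume(row: String): String =
          if (parent != null) parent.doConsume(row) else s"append($row);"
        def doConsume(row: String): String = consume(row)
      }

      // Leaf (plays InputAdapter): emits the input loop, hands each row upward.
      class MiniInputAdapter extends MiniCodegen {
        def doProduce(): String =
          s"while (input.hasNext()) { InternalRow r = input.next(); ${consume("r")} }"
      }

      // Parent (plays SortAggregateExec): delegates produce() to its child,
      // injecting itself as the parent, and contributes buffer-update code.
      class MiniAgg(child: MiniCodegen) extends MiniCodegen {
        def doProduce(): String = { child.parent = this; child.produce() }
        override def doConsume(row: String): String = s"update_buffers($row);"
      }

      // new MiniAgg(new MiniInputAdapter).produce() yields:
      //   while (input.hasNext()) { InternalRow r = input.next(); update_buffers(r); }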
    
    
    The doConsume of SortAggregateExec(Final)

    SortAggregateExec's doConsume ultimately calls doConsumeWithoutKeys(ctx, input), where input is ArrayBuffer(ExprCode(,sortAgg_exprIsNull_0_0,sortAgg_expr_0_0), ExprCode(,sortAgg_exprIsNull_1_0,sortAgg_expr_1_0), ExprCode(,sortAgg_exprIsNull_2_0,sortAgg_expr_2_0)).
    Differences from SortAggregateExec(Partial):

    • updateExprs differs:
      val updateExprs = aggregateExpressions.map { e =>
        e.mode match {
          case Partial | Complete =>
            e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
          case PartialMerge | Final =>
            e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
        }
      }
    

    Because this is the Final stage, the update statements are the mergeExpressions, i.e.:

    protected def getMergeExpressions = Seq(
      /* sum = */ Add(sum.left, sum.right, useAnsiAdd),
      /* count = */ count.left + count.right
    )
    

    The corresponding generated ExprCode instances are:

    List(ExprCode(boolean sortAgg_isNull_13 = true;
         double sortAgg_value_13 = -1.0;
         if (!sortAgg_bufIsNull_1) {
         if (!sortAgg_exprIsNull_1_0) {
               sortAgg_isNull_13 = false; // resultCode could change nullability.
               sortAgg_value_13 = sortAgg_bufValue_1 + sortAgg_expr_1_0;
         }
         },sortAgg_isNull_13,sortAgg_value_13),
    
     ExprCode(boolean sortAgg_isNull_16 = true;
         long sortAgg_value_16 = -1L;
         
         if (!sortAgg_bufIsNull_2) {
         if (!sortAgg_exprIsNull_2_0) {
               sortAgg_isNull_16 = false; // resultCode could change nullability.
               sortAgg_value_16 = sortAgg_bufValue_2 + sortAgg_expr_2_0;
         }
       
         },sortAgg_isNull_16,sortAgg_value_16)))
    

    Note that for sortAgg_value_16, i.e. the COUNT value, the generated code adds sortAgg_expr_2_0, whereas the Partial stage adds 1: the rows consumed here are the result buffers of each task's local aggregation, not raw input rows.
    The results sortAgg_isNull_13/sortAgg_value_13 and sortAgg_isNull_16/sortAgg_value_16 are still assigned back to the global buffer variables:

      sortAgg_bufIsNull_1 = sortAgg_isNull_13;
      sortAgg_bufValue_1 = sortAgg_value_13;
    
      sortAgg_bufIsNull_2 = sortAgg_isNull_16;
      sortAgg_bufValue_2 = sortAgg_value_16;
    
    The consume of SortAggregateExec(Final)
    consume(ctx, resultVars)
    

    Here resultVars is ExprCode(,sortAgg_isNull_4,sortAgg_value_4), which carries the final AVG result value; the rest of the data flow is the same as before. The null handling behind sortAgg_value_4 is sketched below.
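
    A hypothetical, self-contained model of Average's final evaluation (not the Spark source; compare it with the sortAgg_isNull_4/sortAgg_value_4 block inside processNext further down):

      // AVG final evaluation: null when count is null/zero or sum is null,
      // otherwise sum / count -- mirroring the generated code:
      //   if (sortAgg_isNull_6 || sortAgg_value_6 == 0) { sortAgg_isNull_4 = true; }
      //   else { sortAgg_value_4 = sortAgg_bufValue_1 / sortAgg_value_6; }
      def evaluateAvg(sum: Double, sumIsNull: Boolean,
                      count: Long, countIsNull: Boolean): Option[Double] =
        if (countIsNull || count == 0L || sumIsNull) None
        else Some(sum / count)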

    • val rowVar = prepareRowVar(ctx, row, outputVars)
      The code it finally returns is as follows (consistent with the earlier stage):
         sortAgg_mutableStateArray_0[0].reset();
    
         sortAgg_mutableStateArray_0[0].zeroOutNullBytes();
         if (sortAgg_bufIsNull_0) {
           sortAgg_mutableStateArray_0[0].setNullAt(0);
         } else {
           sortAgg_mutableStateArray_0[0].write(0, sortAgg_bufValue_0);
         }
    
         if (sortAgg_isNull_4) {
           sortAgg_mutableStateArray_0[0].setNullAt(1);
         } else {
           sortAgg_mutableStateArray_0[0].write(1, sortAgg_value_4);
         }
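
    As an aside, the UnsafeRowWriter calls above can be exercised standalone. A minimal sketch (the values 9 and 4.5 match this query's expected output; everything else is illustrative):

      import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

      // Write the 2-column result row (max: long at ordinal 0, avg: double at 1)
      // the same way the generated code drives sortAgg_mutableStateArray_0[0].
      val writer = new UnsafeRowWriter(2, 0) // 2 fields, 0 bytes of variable-length section
      writer.reset()
      writer.zeroOutNullBytes()
      writer.write(0, 9L)    // max(id)
      writer.write(1, 4.5)   // avg(id)
      val row = writer.getRow()
      assert(row.getLong(0) == 9L && row.getDouble(1) == 4.5)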
    
    The doConsume of WholeStageCodegenExec

    The data flow is the same as before; the final generated code is:
    append((sortAgg_mutableStateArray_0[0].getRow()));
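
    The append here is BufferedRowIterator.append, which pushes the finished row into the iterator's output queue. Schematically (a heavily trimmed Scala sketch of the idea; the real class is Java inside Spark and differs in detail):

      import java.util.LinkedList
      import org.apache.spark.sql.catalyst.InternalRow

      abstract class MiniBufferedRowIterator {
        protected val currentRows = new LinkedList[InternalRow]()
        // The generated subclass implements processNext() and calls append()
        // from it, exactly like the append(...) line above.
        protected def processNext(): Unit
        protected def append(row: InternalRow): Unit = currentRows.add(row)
        def hasNext: Boolean = {
          if (currentRows.isEmpty) processNext() // refill on demand
          !currentRows.isEmpty
        }
        def next(): InternalRow = currentRows.remove()
      }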

    The final code of the stage-2 wholeStageCodegen is as follows:

    /* 001 */ public Object generate(Object[] references) {
    /* 002 */   return new GeneratedIteratorForCodegenStage2(references);
    /* 003 */ }
    /* 004 */
    /* 005 */ // codegenStageId=2
    /* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 007 */   private Object[] references;
    /* 008 */   private scala.collection.Iterator[] inputs;
    /* 009 */   private boolean sortAgg_initAgg_0;
    /* 010 */   private boolean sortAgg_bufIsNull_0;
    /* 011 */   private long sortAgg_bufValue_0;
    /* 012 */   private boolean sortAgg_bufIsNull_1;
    /* 013 */   private double sortAgg_bufValue_1;
    /* 014 */   private boolean sortAgg_bufIsNull_2;
    /* 015 */   private long sortAgg_bufValue_2;
    /* 016 */   private scala.collection.Iterator inputadapter_input_0;
    /* 017 */   private boolean sortAgg_sortAgg_isNull_10_0;
    /* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] sortAgg_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
    /* 019 */
    /* 020 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
    /* 021 */     this.references = references;
    /* 022 */   }
    /* 023 */
    /* 024 */   public void init(int index, scala.collection.Iterator[] inputs) {
    /* 025 */     partitionIndex = index;
    /* 026 */     this.inputs = inputs;
    /* 027 */
    /* 028 */     inputadapter_input_0 = inputs[0];
    /* 029 */     sortAgg_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
    /* 030 */
    /* 031 */   }
    /* 032 */
    /* 033 */   private void sortAgg_doAggregate_max_0(long sortAgg_expr_0_0, boolean sortAgg_exprIsNull_0_0) throws java.io.IOException {
    /* 034 */     sortAgg_sortAgg_isNull_10_0 = true;
    /* 035 */     long sortAgg_value_10 = -1L;
    /* 036 */
    /* 037 */     if (!sortAgg_bufIsNull_0 && (sortAgg_sortAgg_isNull_10_0 ||
    /* 038 */         sortAgg_bufValue_0 > sortAgg_value_10)) {
    /* 039 */       sortAgg_sortAgg_isNull_10_0 = false;
    /* 040 */       sortAgg_value_10 = sortAgg_bufValue_0;
    /* 041 */     }
    /* 042 */
    /* 043 */     if (!sortAgg_exprIsNull_0_0 && (sortAgg_sortAgg_isNull_10_0 ||
    /* 044 */         sortAgg_expr_0_0 > sortAgg_value_10)) {
    /* 045 */       sortAgg_sortAgg_isNull_10_0 = false;
    /* 046 */       sortAgg_value_10 = sortAgg_expr_0_0;
    /* 047 */     }
    /* 048 */
    /* 049 */     sortAgg_bufIsNull_0 = sortAgg_sortAgg_isNull_10_0;
    /* 050 */     sortAgg_bufValue_0 = sortAgg_value_10;
    /* 051 */   }
    /* 052 */
    /* 053 */   private void sortAgg_doAggregateWithoutKey_0() throws java.io.IOException {
    /* 054 */     // initialize aggregation buffer
    /* 055 */     sortAgg_bufIsNull_0 = true;
    /* 056 */     sortAgg_bufValue_0 = -1L;
    /* 057 */     sortAgg_bufIsNull_1 = false;
    /* 058 */     sortAgg_bufValue_1 = 0.0D;
    /* 059 */     sortAgg_bufIsNull_2 = false;
    /* 060 */     sortAgg_bufValue_2 = 0L;
    /* 061 */
    /* 062 */     while ( inputadapter_input_0.hasNext()) {
    /* 063 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
    /* 064 */
    /* 065 */       boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
    /* 066 */       long inputadapter_value_0 = inputadapter_isNull_0 ?
    /* 067 */       -1L : (inputadapter_row_0.getLong(0));
    /* 068 */       boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1);
    /* 069 */       double inputadapter_value_1 = inputadapter_isNull_1 ?
    /* 070 */       -1.0 : (inputadapter_row_0.getDouble(1));
    /* 071 */       boolean inputadapter_isNull_2 = inputadapter_row_0.isNullAt(2);
    /* 072 */       long inputadapter_value_2 = inputadapter_isNull_2 ?
    /* 073 */       -1L : (inputadapter_row_0.getLong(2));
    /* 074 */
    /* 075 */       sortAgg_doConsume_0(inputadapter_row_0, inputadapter_value_0, inputadapter_isNull_0, inputadapter_value_1, inputadapter_isNull_1, inputadapter_value_2, inputadapter_isNull_2);
    /* 076 */       // shouldStop check is eliminated
    /* 077 */     }
    /* 078 */
    /* 079 */   }
    /* 080 */
    /* 081 */   protected void processNext() throws java.io.IOException {
    /* 082 */     while (!sortAgg_initAgg_0) {
    /* 083 */       sortAgg_initAgg_0 = true;
    /* 084 */       sortAgg_doAggregateWithoutKey_0();
    /* 085 */
    /* 086 */       // output the result
    /* 087 */       boolean sortAgg_isNull_6 = sortAgg_bufIsNull_2;
    /* 088 */       double sortAgg_value_6 = -1.0;
    /* 089 */       if (!sortAgg_bufIsNull_2) {
    /* 090 */         sortAgg_value_6 = (double) sortAgg_bufValue_2;
    /* 091 */       }
    /* 092 */       boolean sortAgg_isNull_4 = false;
    /* 093 */       double sortAgg_value_4 = -1.0;
    /* 094 */       if (sortAgg_isNull_6 || sortAgg_value_6 == 0) {
    /* 095 */         sortAgg_isNull_4 = true;
    /* 096 */       } else {
    /* 097 */         if (sortAgg_bufIsNull_1) {
    /* 098 */           sortAgg_isNull_4 = true;
    /* 099 */         } else {
    /* 100 */           sortAgg_value_4 = (double)(sortAgg_bufValue_1 / sortAgg_value_6);
    /* 101 */         }
    /* 102 */       }
    /* 103 */
    /* 104 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
    /* 105 */       sortAgg_mutableStateArray_0[0].reset();
    /* 106 */
    /* 107 */       sortAgg_mutableStateArray_0[0].zeroOutNullBytes();
    /* 108 */
    /* 109 */       if (sortAgg_bufIsNull_0) {
    /* 110 */         sortAgg_mutableStateArray_0[0].setNullAt(0);
    /* 111 */       } else {
    /* 112 */         sortAgg_mutableStateArray_0[0].write(0, sortAgg_bufValue_0);
    /* 113 */       }
    /* 114 */
    /* 115 */       if (sortAgg_isNull_4) {
    /* 116 */         sortAgg_mutableStateArray_0[0].setNullAt(1);
    /* 117 */       } else {
    /* 118 */         sortAgg_mutableStateArray_0[0].write(1, sortAgg_value_4);
    /* 119 */       }
    /* 120 */       append((sortAgg_mutableStateArray_0[0].getRow()));
    /* 121 */     }
    /* 122 */   }
    /* 123 */
    /* 124 */   private void sortAgg_doConsume_0(InternalRow inputadapter_row_0, long sortAgg_expr_0_0, boolean sortAgg_exprIsNull_0_0, double sortAgg_expr_1_0, boolean sortAgg_exprIsNull_1_0, long sortAgg_expr_2_0, boolean sortAgg_exprIsNull_2_0) throws java.io.IOException {
    /* 125 */     // do aggregate
    /* 126 */     // common sub-expressions
    /* 127 */
    /* 128 */     // evaluate aggregate functions and update aggregation buffers
    /* 129 */     sortAgg_doAggregate_max_0(sortAgg_expr_0_0, sortAgg_exprIsNull_0_0);
    /* 130 */     sortAgg_doAggregate_avg_0(sortAgg_exprIsNull_1_0, sortAgg_expr_1_0, sortAgg_exprIsNull_2_0, sortAgg_expr_2_0);
    /* 131 */
    /* 132 */   }
    /* 133 */
    /* 134 */   private void sortAgg_doAggregate_avg_0(boolean sortAgg_exprIsNull_1_0, double sortAgg_expr_1_0, boolean sortAgg_exprIsNull_2_0, long sortAgg_expr_2_0) throws java.io.IOException {
    /* 135 */     boolean sortAgg_isNull_13 = true;
    /* 136 */     double sortAgg_value_13 = -1.0;
    /* 137 */
    /* 138 */     if (!sortAgg_bufIsNull_1) {
    /* 139 */       if (!sortAgg_exprIsNull_1_0) {
    /* 140 */         sortAgg_isNull_13 = false; // resultCode could change nullability.
    /* 141 */
    /* 142 */         sortAgg_value_13 = sortAgg_bufValue_1 + sortAgg_expr_1_0;
    /* 143 */
    /* 144 */       }
    /* 145 */
    /* 146 */     }
    /* 147 */     boolean sortAgg_isNull_16 = true;
    /* 148 */     long sortAgg_value_16 = -1L;
    /* 149 */
    /* 150 */     if (!sortAgg_bufIsNull_2) {
    /* 151 */       if (!sortAgg_exprIsNull_2_0) {
    /* 152 */         sortAgg_isNull_16 = false; // resultCode could change nullability.
    /* 153 */
    /* 154 */         sortAgg_value_16 = sortAgg_bufValue_2 + sortAgg_expr_2_0;
    /* 155 */
    /* 156 */       }
    /* 157 */
    /* 158 */     }
    /* 159 */
    /* 160 */     sortAgg_bufIsNull_1 = sortAgg_isNull_13;
    /* 161 */     sortAgg_bufValue_1 = sortAgg_value_13;
    /* 162 */
    /* 163 */     sortAgg_bufIsNull_2 = sortAgg_isNull_16;
    /* 164 */     sortAgg_bufValue_2 = sortAgg_value_16;
    /* 165 */   }
    /* 166 */
    /* 167 */ }
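
    Finally, for context on how this class is actually used: WholeStageCodegenExec.doExecute compiles the source and instantiates it once per partition. A simplified sketch of that driver loop (adapted and heavily trimmed, not the verbatim Spark implementation; the generate function stands in for calling generate(references) on the compiled class):

      import org.apache.spark.sql.catalyst.InternalRow
      import org.apache.spark.sql.execution.BufferedRowIterator

      // Per partition: instantiate the compiled class, wire the upstream
      // iterator into inputs[0] (inputadapter_input_0 above), and expose
      // BufferedRowIterator as a plain row iterator.
      def drivePartition(generate: Array[AnyRef] => BufferedRowIterator,
                         references: Array[AnyRef],
                         partitionIndex: Int,
                         input: Iterator[InternalRow]): Iterator[InternalRow] = {
        val buffer = generate(references)
        buffer.init(partitionIndex, Array(input))
        new Iterator[InternalRow] {
          override def hasNext: Boolean = buffer.hasNext // triggers processNext() as needed
          override def next(): InternalRow = buffer.next()
        }
      }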
    
  • Original article: https://blog.csdn.net/monkeyboy_tech/article/details/126950809