• 手撕SparkSQL五大JOIN的底层机制


    关联形式(Join Types)都有哪些

    我个人习惯还是从源码里面定义入手,一方面如果有调整,大家知道怎么去查,另一方面来说,没有什么比起源码的更加官方的定义了。SparkSQL中的关于JOIN的定义位于
    org.apache.spark.sql.catalyst.plans.JoinType,按照包的划分,JOIN其实是执行计划的一部分。
    在这里插入图片描述

    具体的定义可以在JoinType的伴生对象中apply方法有构造。

    object JoinType {
      def apply(typ: String): JoinType = typ.toLowerCase(Locale.ROOT).replace("_", "") match {
        case "inner" => Inner
        case "outer" | "full" | "fullouter" => FullOuter
        case "leftouter" | "left" => LeftOuter
        case "rightouter" | "right" => RightOuter
        case "leftsemi" | "semi" => LeftSemi
        case "leftanti" | "anti" => LeftAnti
        case "cross" => Cross
        case _ =>
          val supported = Seq(
            "inner",
            "outer", "full", "fullouter", "full_outer",
            "leftouter", "left", "left_outer",
            "rightouter", "right", "right_outer",
            "leftsemi", "left_semi", "semi",
            "leftanti", "left_anti", "anti",
            "cross")
    
          throw new IllegalArgumentException(s"Unsupported join type '$typ'. " +
            "Supported join types include: " + supported.mkString("'", "', '", "'") + ".")
      }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    当然,我们其实可以很清楚看出来为什么我们平时说left outer join 和left join 其实是一样的。

    关联形式关键字
    内关联inner
    外关联outer、full、fullouter、full_outer
    左关联leftouter、left、left_outer
    右关联rightouter、right、right_outer
    左半关联leftsemi、left_semi、semi
    左逆关联leftanti、left_anti、anti
    交叉连接(笛卡尔积)cross

    前面几个大家比较熟悉,我说一下后面三个,后面三个其实用得挺多的,只是我们经常是
    用一些其他语法表达而已,因为后面我们演示数据,提前建立表

    import spark.implicits._
        import org.apache.spark.sql.DataFrame
    
        // 学生表
        val seq = Seq((1, "小明", 28, "男","二班"), (2, "小丽", 22, "女","四班"), (3, "阿虎", 24, "男","三班"), (5, "张强", 18, "男","四班"))
        val students: DataFrame = seq.toDF("id", "name", "age", "gender","class")
        students.show(10,false);
        students.createTempView("students")
        // 班级表
        val seq2 = Seq(("三班",3),("四班",4),("三班",1))
        val classes:DataFrame = seq2.toDF("class_name", "id")
        classes.show(10,false)
        classes.createTempView("classes")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    +---+----+---+------+-----+
    |id |name|age|gender|class|
    +---+----+---+------+-----+
    |1  |小明|28 ||二班 |
    |2  |小丽|22 ||四班 |
    |3  |阿虎|24 ||三班 |
    |5  |张强|18 ||四班 |
    +---+----+---+------+-----+
    
    +---+----+---+------+-----+
    | id|name|age|gender|class|
    +---+----+---+------+-----+
    |  2|小丽| 22|| 四班|
    |  3|阿虎| 24|| 三班|
    |  5|张强| 18|| 四班|
    +---+----+---+------+-----+
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    left_semi join

    这两个刚好是一对,一个是要剔除,一个是要保留
    我们经常有这种需求,我们要查询班级表中存在的学生姓名,我们会这样写SQL

    select id,name,age,gender,class from students where class in (select class_name from classes group by class_name)
    
    • 1

    是一种存在需求,转换成我们的left_semi就是

    select id,name,age,gender,class from students left semi join classes on students.class=classes.class_name
    
    • 1

    当然,也不会忘记了我们的代码形式

    val leftsemiDF: DataFrame = students.join(classes, students("class") === classes("class_name"), "leftsemi")
    
    • 1

    三次运算结果都是一致的,因为班级信息里面没有二班,所以只有三班和4班的信息,这一类需求是我们实现exists 和 一些where 条件中in的时候大量使用

    
    +---+----+---+------+-----+
    | id|name|age|gender|class|
    +---+----+---+------+-----+
    |  2|小丽| 22|| 四班|
    |  3|阿虎| 24|| 三班|
    |  5|张强| 18|| 四班|
    +---+----+---+------+-----+
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    left_anti join

    left_anti其实是和semi相反的工作前面是实现了保留,这个则是去掉,我们其实简单换成left_anti 可以看到类似的效果,我们希望看到不在班级表里面的学生,我们会这样写sql

    select id,name,age,gender,class from students where class not in (select class_name from classes group by class_name)
    
    • 1

    转换成我们的anti join则是

    select id,name,age,gender,class from students left anti join classes on students.class=classes.class_name
    
    • 1

    当然,也有代码版本

     val leftantiDF: DataFrame = students.join(classes, students("class") === classes("class_name"), "leftanti")
    
    • 1

    结果如下,只有2班的小明了:

    +---+----+---+------+-----+
    | id|name|age|gender|class|
    +---+----+---+------+-----+
    |  1|小明| 28|| 二班|
    +---+----+---+------+-----+
    
    • 1
    • 2
    • 3
    • 4
    • 5

    cross join

    这个其实是笛卡尔积,我们经常忘记写join条件的时候,这种就是笛卡尔积了,
    因为这个操作是很容易把程序搞崩的,所以要加上配置
    spark.sql.crossJoin.enabled=true

    spark.conf.set("spark.sql.crossJoin.enabled", "true")
    students.join(classes).show(10,false)
    
    • 1
    • 2

    结果如下:

    +---+----+---+------+-----+----------+---+
    |id |name|age|gender|class|class_name|id |
    +---+----+---+------+-----+----------+---+
    |1  |小明|28 ||二班 |三班      |3  |
    |1  |小明|28 ||二班 |四班      |4  |
    |1  |小明|28 ||二班 |三班      |1  |
    |2  |小丽|22 ||四班 |三班      |3  |
    |2  |小丽|22 ||四班 |四班      |4  |
    |2  |小丽|22 ||四班 |三班      |1  |
    |3  |阿虎|24 ||三班 |三班      |3  |
    |3  |阿虎|24 ||三班 |四班      |4  |
    |3  |阿虎|24 ||三班 |三班      |1  |
    |5  |张强|18 ||四班 |三班      |3  |
    +---+----+---+------+-----+----------+---+
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    JOIN的实现机制

    Spark中的JOIN对应BaseJoinExec的五个子类,他们分别是 BroadcastHashJoinExec、BroadcastNestedLoopJoinExec、ShuffledHashJoinExec、SortMergeJoinExec、CartesianProductExec,源码关系如下:
    在这里插入图片描述

    可能大家平时没有太关注这里头有啥联系,但是当我们把这些摆在一起的时候我们其实很明显发现Broadcast和Hash就出现了个,我们可以大胆猜测,这里头有必然的联系,其实Broadcast和Shuffled其实是数据分发的形式,SortMergeJoinExec其实也是通过Shuffle的分发,只是类取名字的时候没有写成ShuffledSortMergeJoinExec 看着是有点长吧,还一点就是走了归并排序其实就是不会走广播了。我们按照分发方式可以整理出一个小表格:

    分发方式关键字
    无分发CartesianProductExec
    BroadcastBroadcastNestedLoopJoinExec
    BroadcastBroadcastHashJoinExec
    ShuffledShuffledHashJoinExec
    ShuffledSortMergeJoinExec

    CartesianProductExec 无分发

    为了搞明白计算原理,我们通过源码来研究研究。首其实可以想得到,笛卡尔积投影是直接把数据的所有行都按照交叉膨胀,这个事情直接在Map端组合完成分发就好了,代码里面其实也是这样子的,我们一起看看,我把CartesianProductExec计算部分拿出来。

    override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
        val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold)
        val partition = split.asInstanceOf[CartesianPartition]
        rdd2.iterator(partition.s2, context).foreach(rowArray.add)
        // Create an iterator from rowArray
        def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator()
        val resultIter =
          for (x <- rdd1.iterator(partition.s1, context);
               y <- createIter()) yield (x, y)
        CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
          resultIter, rowArray.clear())
      }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    我们要抠关键部分,其实核心点就是,x,y各自其实就是作为两边的元素,我们理解一下我们操作集合的时候yide的操作。

     val resultIter =
          for (x <- rdd1.iterator(partition.s1, context);
               y <- createIter()) yield (x, y)
    
    • 1
    • 2
    • 3

    这个其实就是scala中实现二重循环的操作,spark上面的逻辑可以类比下面的逻辑,我把学生join笛卡尔积的操作还原如下:

     val students=Array((1,"小明"),(2,"小丽"))
        val classes= Array("三班","二班")
        val result= for(student <-students; clazz <- classes) yield (student,clazz)
        result.foreach(print)
    
    • 1
    • 2
    • 3
    • 4

    结果如下:

    ((1,小明),三班)((1,小明),二班)((2,小丽),三班)((2,小丽),二班)
    
    
    • 1
    • 2

    发现没有,其实就是我们需要的结果了,注意哦,Spark源码就是这样的,是不是信心倍增。

    BroadcastNestedLoopJoinExec

    这个其实就是字面上的含义,广播+嵌套循环实现join,我们一直在说,广播其实是一种分发方式,在我们之前的文章也有说到,其实广播来说,我们在rdd执行的时候,就直接可以当成拿到本地变量而已,我还是把核心代码拿出来:

    protected override def doExecute(): RDD[InternalRow] = {
        val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
    
        val resultRdd = (joinType, buildSide) match {
          case (_: InnerLike, _) =>
            innerJoin(broadcastedRelation)
          case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
            outerJoin(broadcastedRelation)
          case (LeftSemi, _) =>
            leftExistenceJoin(broadcastedRelation, exists = true)
          case (LeftAnti, _) =>
            leftExistenceJoin(broadcastedRelation, exists = false)
          case (_: ExistenceJoin, _) =>
            existenceJoin(broadcastedRelation)
          case _ =>
            /**
             * LeftOuter with BuildLeft
             * RightOuter with BuildRight
             * FullOuter
             */
            defaultJoin(broadcastedRelation)
        }
    
        val numOutputRows = longMetric("numOutputRows")
        resultRdd.mapPartitionsWithIndexInternal { (index, iter) =>
          val resultProj = genResultProjection
          resultProj.initialize(index)
          iter.map { r =>
            numOutputRows += 1
            resultProj(r)
          }
        }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    最前面 broadcastedRelation,其实就是从广播的变量中获取到数据,这个就是广播的操作了,剩下的就是NestedLoop的事情了,我们注意到代中的(joinType, buildSide) match条件,是按照操作类型不同去实现,我们一起看看innser join 的操作:

    private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
        streamed.execute().mapPartitionsInternal { streamedIter =>
          val buildRows = relation.value
          val joinedRow = new JoinedRow
    
          streamedIter.flatMap { streamedRow =>
            val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r))
            if (condition.isDefined) {
              joinedRows.filter(boundCondition)
            } else {
              joinedRows
            }
          }
        }
      }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    这个操作的意思是把集合打平,关键点就是streamedIter.flatMap()中的操作,这部分就是嵌套循环,这个也就是为什么叫做NestedLoop的原因,其实就是嵌套循环的意思,我把这个操作按照我们数据集等价实现一遍:

    val students=Array((1,"小明"),(2,"小丽"))
    val classes= Array("三班","二班")
    students.flatMap(student=>{
         classes.map(clazz=>{
           print(student._1,student._2,clazz)
         })
       })
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    结果如下:

    (1,小明,三班)(1,小明,二班)(2,小丽,三班)(2,小丽,二班)
    
    • 1

    我们很明显看得出时间复杂度是M*N,看到这里。

    BroadcastHashJoinExec

    有了前面的基础,对于Broadcast类型的操作,我们可以进一步归纳,三部曲
    1、从广播变量中取值
    2、完成关联操作
    3、输出结果
    有了这些操作,我们可以预判代码了

      protected override def doExecute(): RDD[InternalRow] = {
    	val numOutputRows = longMetric("numOutputRows")
        val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
        if (isNullAwareAntiJoin) {
        		//Anti join是尝试解决子查询反嵌套(Subquery Unnesting)中和NULl值相关的各种问题,这里不展开,也是
        		//为了代码少一些
        } else {
          streamedPlan.execute().mapPartitions { streamedIter =>
            val hashed = broadcastRelation.value.asReadOnlyCopy()
            TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
            join(streamedIter, hashed, numOutputRows)
          }
        }
        }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    广播方式就是从广播变量中获取我们要的结果
    val broadcastRelation = buildPlan.executeBroadcastHashedRelation
    我们需要关注else 里面的内容,val hashed = broadcastRelation.value.asReadOnlyCopy()其实就是拿到了我们hash的清单,处理之后会在后面的join操作中进一步处理,我们需要深入JOIN内部

    protected def join(
          streamedIter: Iterator[InternalRow],
          hashed: HashedRelation,
          numOutputRows: SQLMetric): Iterator[InternalRow] = {
    
        val joinedIter = joinType match {
          case _: InnerLike =>
            innerJoin(streamedIter, hashed)
          case LeftOuter | RightOuter =>
            outerJoin(streamedIter, hashed)
          case LeftSemi =>
            semiJoin(streamedIter, hashed)
          case LeftAnti =>
            antiJoin(streamedIter, hashed)
          case _: ExistenceJoin =>
            existenceJoin(streamedIter, hashed)
          case x =>
            throw new IllegalArgumentException(
              s"HashJoin should not take $x as the JoinType")
        }
    
        val resultProj = createResultProjection
        joinedIter.map { r =>
          numOutputRows += 1
          resultProj(r)
        }
      }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    JOIN内部的主流程是按照JOIN类型各自处理,返回JOIN之后的结果,我们还是以Inner为例,分析一下实现的过程

    private def innerJoin(
          streamIter: Iterator[InternalRow],
          hashedRelation: HashedRelation): Iterator[InternalRow] = {
        val joinRow = new JoinedRow
        val joinKeys = streamSideKeyGenerator()
    
        if (hashedRelation == EmptyHashedRelation) {
          Iterator.empty
        } else if (hashedRelation.keyIsUnique) {
          streamIter.flatMap { srow =>
            joinRow.withLeft(srow)
            val matched = hashedRelation.getValue(joinKeys(srow))
            if (matched != null) {
              Some(joinRow.withRight(matched)).filter(boundCondition)
            } else {
              None
            }
          }
        } else {
          streamIter.flatMap { srow =>
            joinRow.withLeft(srow)
            val matches = hashedRelation.get(joinKeys(srow))
            if (matches != null) {
              matches.map(joinRow.withRight).filter(boundCondition)
            } else {
              Seq.empty
            }
          }
        }
      }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    我们关注到核心操作,从hash表中取到我们想要的元素,实现join,同样的,用我们的数据集实现一下

     	import scala.collection.mutable.HashMap
        val students = Array((1, "小明", "002"), (2, "小丽", "003"))
        val hashClass = HashMap[String, String]("003" -> "三班", "002" -> "二班")
        students.map(student => (student, hashClass.get(student._3))).foreach(print)
    
    • 1
    • 2
    • 3
    • 4

    因为本身的逻辑就行整个遍历一遍重新组合结果,时间复杂度是O(N)

    ((1,小明,002),Some(二班))((2,小丽,003),Some(三班))
    
    • 1

    阶段性总结

    前面介绍的三种JOIN方式其实已经可以完全实现所有的JOIN操作了,但是这些操作有一个特点,作为主表我们可以分成不同的Partition上面执行,但是从表我们其实是清一色作为Executor本地方式执行的,因为我们的Task是分布在很多集群上运行的,所有我们为了让所有的节点都有这份数据,所有是往所有节点都分发一次,这也是为啥叫做广播的原因。

    这些方式在数据量不大的时候是很高效的,这个数据量的规模可以是10万级到百万级不等,也就是说其实可以控制的,源码中的定义如下,我们可以看到其实后面给我们有一个默认值10MB

        val AUTO_BROADCASTJOIN_THRESHOLD = buildConf("spark.sql.autoBroadcastJoinThreshold")
        .doc("...解释信息")
        .version("1.1.0")
        .bytesConf(ByteUnit.BYTE)
        .createWithDefaultString("10MB")
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    有这么一句描述 By setting this value to -1 broadcasting can be disabled,就是给-1可以关闭这个数值,这种10M的概念一般就是几万以内,和你本身字段数量还有数据内容有关系,如果不满足可以调整。
    我们做如下设置,就可以不打开广播了

    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
    
    • 1

    还一个情况是大家网络上看到说广播方式数据量小的时候其实适合,大的时候慢。其实这里在生产中的实际情况是,数据量比较大的时候他是跑不出来,不是慢的问题,因为一个任务有时候执行下来本身就是要大半个小时,但是跑一个小时之后来个失败,这个才是忍受不了的,生产环境下游又有依赖,所以这个时候我们不再是追求快了,而是都在求神拜佛就行只要跑出来就可以,哪怕是再等等,但是别失败。这才是实际的情况。基于此,我们才会引入ShuffledHashJoinExec和SortMergeJoinExec的计算方式,大家注意,这两种方式是为了实现更大规模数据量的JOIN而产生的,就整体时间效率上远不如前面的方式,但是最大的好处是可以出结果呀,加上企业实际生产过程中数据量其实都是很庞大的,所以这两种方式才是在生产上大量存在的操作方式。

    ShuffledHashJoinExec

    ShuffledHashJoinExec方式其实在join操作中获取从表信息还是从Hash中获取,这里的差别在于,我们本来需要做广播的表太大了,所以我们需要把广播的表通Shuffle的方式把一个大表分解成小表生成hash,是怎么个原理呢。首先我们看实现源码:

    protected override def doExecute(): RDD[InternalRow] = {
        val numOutputRows = longMetric("numOutputRows")
        streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
          val hashed = buildHashedRelation(buildIter)
          joinType match {
            case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
            case _ => join(streamIter, hashed, numOutputRows)
          }
        }
      }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    ShuffledHashJoinExec的后面部分其实是复用了前面的BroadcastHashJoinExec的操作,实现来说还是需要获取到一个HashedRelation,这个也是我们的构建HashTable的部分,后面就是会在本地执行join操作了,主要差别是来自获取HashedRelation的时候,前者是从广播变量中获取,这里不再是了:

    def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
        val buildDataSize = longMetric("buildDataSize")
        val buildTime = longMetric("buildTime")
        val start = System.nanoTime()
        val context = TaskContext.get()
        val relation = HashedRelation(
          iter,
          buildBoundKeys,
          taskMemoryManager = context.taskMemoryManager(),
          // Full outer join needs support for NULL key in HashedRelation.
          allowsNullKey = joinType == FullOuter)
        buildTime += NANOSECONDS.toMillis(System.nanoTime() - start)
        buildDataSize += relation.estimatedSize
        // This relation is usually used until the end of task.
        context.addTaskCompletionListener[Unit](_ => relation.close())
        relation
      }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    核心部分其实是来自,这部分其实就是一个Shuffle的迭代结果,也就是说我们获取这部分Hash的时候是按照buildBoundKeys作为范围内的获取,不再是整个表的范围了

    val relation = HashedRelation(
         iter,
         buildBoundKeys,
         taskMemoryManager = context.taskMemoryManager(),
         // Full outer join needs support for NULL key in HashedRelation.
         allowsNullKey = joinType == FullOuter)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    当然,从理解上还是简单的,因为Spark在实现设计上非常巧妙,我们还是可以当成一个Map来看待,只不过这个时候的数据范围是局部地区。
    我们同样做点实战例子来理解计算过程,因为比起前面还是多一些步骤,我们画图理解一下:


    学生信息作为主表目标就是找到和班级编号匹配的主表,在此之前,班级编号做一次partition操作,会把数据分布在不同分区里面去,与此同时,班级信息也进行partition,这样相同班级编号的信息会落到同一个partition中去,针对单独的partition就是实现和之前的HashJoin一样的操作了。要注意的是partition就是Shuffle实现的,我们都知道shuffle的时候是需要定义一个hashpartition的操作,所以这个操作其实是有两次的hash,第一次把数据进行分区,第二次,实现关联操作。我们把整个过程用代码模拟出来,我这里分区操作不会那么复杂,仅仅按照%2的方式分发:

    //对班级编号进行分区
      def partition(key:String):Integer={
        val keyNum= Integer.valueOf(key)
        keyNum%2
      }
    
    • 1
    • 2
    • 3
    • 4
    • 5

    接下来是模拟Shuffle的操作,因为我们是需要把班级信息通过shuffle的方式分发,我们实现了把多个HashMap分发到了partitions内部

     def shuffleHash(classes:Array[(String, String)]):mutable.HashMap[Integer,mutable.HashMap[String,(String, String)]]={
          val partitions=new mutable.HashMap[Integer,mutable.HashMap[String,(String, String)]]();
          for(clazz <- classes){
            val hashKey=partition(clazz._1)
            val map=partitions.getOrElse(hashKey,mutable.HashMap[String,(String, String)]())
            map += (clazz._1->clazz)
            partitions += hashKey->map
          }
          partitions
      }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    最后,我们实现真正的JOIN操作:

    val students = Array[(Integer,String,String)]((1, "小黑", "001"),(2, "小明", "002"), (3, "小丽", "003"),(4, "小红", "004"))
        val classes: Array[(String, String)] = Array(("001", "一班"), ("002", "一班"), ("003", "三班"), ("004", "四班"))
    
       val partitions= shuffleHash(classes)
        students.map(student=>{
         val partitionId= partition(student._3)
         val  classHashMap= partitions.get(partitionId).get
         val clazz= classHashMap.get(student._3)
          println("分区:"+partitionId,"学生:"+student,"班级:"+clazz.get)
        })
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    查看结果:

    (分区:1,学生:(1,小黑,001),班级:(001,一班))
    (分区:0,学生:(2,小明,002),班级:(002,一班))
    (分区:1,学生:(3,小丽,003),班级:(003,三班))
    (分区:0,学生:(4,小红,004),班级:(004,四班))
    
    • 1
    • 2
    • 3
    • 4

    这个也能理解,为什么Shuffle操作一定要把相同的数据分发到相同的分区里面去,其实这个就是一个分治的算法,是不是感觉也没那么神秘了~~

    SortMergeJoinExec

    我们其实注意到,算法的不断演进就是为了实现不同数据规模的情况,SortMergeJoinExec的话我们同样适用,也许是名字会太长的关系,如果我们补充完全,按照数据生成方式,我们可以命名为ShuffledSortMergeJoinExec,这就是说需要通过Shuffle的方式生成,同时在合并的时候走的是SortMerge,我们可以从类的继承关系上看出来。

    case class SortMergeJoinExec(
        leftKeys: Seq[Expression],
        rightKeys: Seq[Expression],
        joinType: JoinType,
        condition: Option[Expression],
        left: SparkPlan,
        right: SparkPlan,
        isSkewJoin: Boolean = false) extends ShuffledJoin {
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    这个是因为,不管是前面的Hash方式做Shuffle,我们都用到了一个classHashMap的类,这个操作时间复杂度O(1),快得很,但是架不住数据多呀,数据量比较庞大的时候classHashMap其实是转不下的,所以我还是那句话,这种时候其实是跑不出来,因为内存会打爆,并不是真的就慢的原因,因为按照慢来说至少等一等可以,但是如果超过了本身这种处理方式的能力的话,整个计算进行不下去,我才迫使我们不得不找到新的计算方式,可能性能上确实没那么快,但是至少跑得出来。

    val  classHashMap= partitions.get(partitionId).get 受内存局限
         val clazz= classHashMap.get(student._3)
    
    • 1
    • 2

    SortMerge这个是大家的印象应该是既熟悉又陌生,因为大量的作业都是用这种方式的,所以老是可以看到,但是陌生是因为也不知道里头什么个机制,我们今天来倒腾倒腾。前面我也说过Shuffle是给我们把数据按照相同的HashPartition算法分发到相同的分区中去,这个和前面操作是完全一样的,SortMerge就是解决怎么把两边的数据JOIN在一起的问题,所以我们重点关注一下SortMerge的操作。还是要关注源码中的执行操作,我们需要注意到,这一次的输入内容和前面是有差别的leftIter, rightIter是两个迭代器的输入操作了,不再是Hash,所以本身这种操作上算法是解决两个list的合并操作

    protected override def doExecute(): RDD[InternalRow] = {
        val numOutputRows = longMetric("numOutputRows")
        val spillThreshold = getSpillThreshold
        val inMemoryThreshold = getInMemoryThreshold
        left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
          val boundCondition: (InternalRow) => Boolean = {
            condition.map { cond =>
              Predicate.create(cond, left.output ++ right.output).eval _
            }.getOrElse {
              (r: InternalRow) => true
            }
          }
          ...
          }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    doExecute()下面的代码其实是比较长的,这里刚好一起学习学习这种长代码的研究思路,
    首先,我们要知道doExecute(): RDD[InternalRow] 一定需要返回一个RDD,这个是我们需要的结果,那么在接下来的代码就一定有地方在返回,代码比较长的时候我们先折叠一下长的部分:

    折叠之后我们就可以看清楚执行的主干了,整个的逻辑是对应不同的JOIN类型做了一个
    new RowIterator()的操作,而RowIterator.toScala就可以返回我们RDD[InternalRow] 了,顺着这个思路,我们其实可以梳理出不同JOIN对应的而RowIterator类型

    JOIN类型RowIterator实现
    InnerLikenew RowIterator 重写了advanceNext和 getRow
    LeftOuterLeftOuterIterator
    RightOuterRightOuterIterator
    FullOuterFullOuterIterator
    LeftSeminew RowIterator 重写了advanceNext和 getRow
    LeftAntinew RowIterator 重写了advanceNext和 getRow
    ExistenceJoinnew RowIterator 重写了advanceNext和 getRow

    ok到了这一步,我们对整个的返回就很清楚了,所以整个实现的触发入口其实是
    RowIterator.toScala我们进一步查看RowIterator的实现

    abstract class RowIterator {
      def advanceNext(): Boolean
      def getRow: InternalRow
      def toScala: Iterator[InternalRow] = new RowIteratorToScala(this)
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5

    RowIterator是一个抽象类,toScala里面其实是new RowIteratorToScala(this)做了这件事情,具体RowIteratorToScala的实现如下:

    private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] {
     private [this] var hasNextWasCalled: Boolean = false
     private [this] var _hasNext: Boolean = false
     override def hasNext: Boolean = {
       // Idempotency:
       if (!hasNextWasCalled) {
         _hasNext = rowIter.advanceNext()
         hasNextWasCalled = true
       }
       _hasNext
     }
     override def next(): InternalRow = {
       if (!hasNext) throw QueryExecutionErrors.noSuchElementExceptionError()
       hasNextWasCalled = false
       rowIter.getRow
     }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    看到这里,我们其实很清晰了,整个过程其实就是在实现一个Iterator操作,我们都知道,RDD本身就是一个Iterator在做RDD的迭代时候,核心方法就是触发hasNext和 next方法,这两个方法也是Iterator的入口,我们很清楚看到在hasNext调用了advanceNext操作,而在
    next中我们从传入的rowIter: RowIterator调用了getRow操作,这两个方法就是在前面做了重写的地方,整个实现结构我们就串联起来了。接下来我们只需要关注advanceNext()的实现:

    
                  while (rightMatchesIterator != null) {
                    if (!rightMatchesIterator.hasNext) {
                      if (smjScanner.findNextInnerJoinRows()) {
                        currentRightMatches = smjScanner.getBufferedMatches
                        currentLeftRow = smjScanner.getStreamedRow
                        rightMatchesIterator = currentRightMatches.generateIterator()
                      } else {
                        currentRightMatches = null
                        currentLeftRow = null
                        rightMatchesIterator = null
                        return false
                      }
                    }
                    joinRow(currentLeftRow, rightMatchesIterator.next())
                    if (boundCondition(joinRow)) {
                      numOutputRows += 1
                      return true
                    }
                  }
                  false
                }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    看到这里,我们可以清楚看到整个merge逻辑了,rightMatchesIterator和smjScanner就是我们做join之后左右两边的Iterator, 而joinRow(currentLeftRow, rightMatchesIterator.next())就是把当前获取到的数据行拼接起来,这里是因为前提是我们的数据行在两边都是已经做了排序的,所以只需要把迭代器往前面移动即可,并不需要做Hash的时候的查找操作。还是按照我们的优良传统,我来一波小代码实现一把:

     def merge(students: Array[(Integer,String,String)],classes: Array[(String, String)]):Unit= {
        val stuIterator=students.iterator
        val claIterator=classes.iterator
        var curRow=claIterator.next()//当前行
        var stuRow=stuIterator.next()  //获取当前行的数据
    
        while (stuRow !=null ) {
          var needNext=true
           if(stuRow._3==curRow._1){
             //学生的班号和班级编号相等,就join起来
             println("学生:"+stuRow,"班级:"+curRow)
           }else{
             //匹配不上的情况,班号往后移动
             curRow=claIterator.next()
             needNext=false
           }
          if(needNext ){
            if(stuIterator.hasNext){
              stuRow=stuIterator.next()
            }else{
              stuRow=null
            }
          }
        }
      }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    merge实现的是对已经排序之后的集合进行排序,所以我们在输入的时候要保证一下顺序

     val students = Array[(Integer,String,String)]((1, "小黑", "001"),(2, "小明", "002"), (3, "小丽", "003"),(4, "小红", "004"))
        val classes: Array[(String, String)] = Array(("001", "一班"), ("002", "二班"), ("003", "三班"), ("004", "四班"))
        merge(students,classes)
    
    • 1
    • 2
    • 3

    结果如下:

    (学生:(1,小黑,001),班级:(001,一班))
    (学生:(2,小明,002),班级:(002,二班))
    (学生:(3,小丽,003),班级:(003,三班))
    (学生:(4,小红,004),班级:(004,四班))
    
    • 1
    • 2
    • 3
    • 4

    我们图形化展示一下

    总结

    小代码虽然有点糙,但是那个是精华^^

  • 相关阅读:
    自动驾驶:未来的道路上的挑战与机遇
    SA8155 QNX 命令
    【unity小技巧】适用于任何 2d 游戏的钥匙门系统和buff系统——UnityEvent的使用
    实用笔记-java配置
    ubuntu访问github慢
    mac电脑版MATLAB R2023b for Mac中文激活版
    mybatis-plus根据指定条件批量更新
    Smale 论文列表: 粒计算方向,特别是属性约简
    能解决 80% 故障的排查思路
    带你了解树的全家桶(BST树到AVL树到B树到B+树)
  • 原文地址:https://blog.csdn.net/zhuxuemin1991/article/details/125962484