• SPark学习笔记:11 SparkSQL 的用户自定义函数UDF、UDAF、UDTF


    UDF 用户自定义函数(一对一)

    说明

    UDF输入一条记录,输出一条记录,一对一的关系,有点类似于map算子,是一对一的关系

    使用

    UDF的使用有两种方式,一种方式是在SQL中使用,另一种方式是在DSL方式使用

    • 使用SQL的方式
    val myconcat3 = sparkSession.udf.register("myconcat3",new MyConcat)
    
    sparkSession.sql("select myconcat3(id,timestamp) as newid,temperature")
    
    • 1
    • 2
    • 3
    • 在DSL中使用
    val myconcat3 = sparkSession.udf.register("myconcat3",new MyConcat)
    
    df.select(myconcat3($"id",$"timestamp") as "newid",$"temperature").show()
    
    • 1
    • 2
    • 3

    实现方式

    UDF函数有3中实现方式:

    • 使用匿名函数
    val myconcat = sparkSession.udf.register("myconcat",(data1:String,data2:Long)=>{
      data1.concat(data2.toString)
    })
    
    • 1
    • 2
    • 3
    • 使用udf函数实现
    //引入udf方法
    import org.apache.spark.sql.functions.{col, udf}
    val myconcat2 = udf[String,String,Long]((data1:String,data2:Long)=>{
      data1.concat(data2.toString)
    })
    
    • 1
    • 2
    • 3
    • 4
    • 5

    说明: udf函数可以有多个输入参数,如上述我们实现的是两个输入参数则udf的原型是udf[R,T1,T2]

    R表示返回值的类型,T1表示第一个参数的类型,T2表示第2个参数的类型
    
    • 1
    • 继承Function函数接口的方式
    class MyConcat extends Function2[String,Long,String] with Serializable{
    override def apply(v1: String, v2: Long): String = {
      v1.concat(v2.toString)
    }
    }
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    说明:Function2[T1,T2,R]

    • 因为我们要实现的UDF有两个输入参数,所以需要继承的是Function2
    • T1 表示第一个参数的类型,T2表示第二个参数的类型
    • R表示是返回值的类型
    • 继承Function Trait的时候,还需要继承Serializable接口,不然会报错。

    完整示例

    package com.hjt.yxh.hw.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, udf}
    
    object UDFApp {
      def main(args: Array[String]): Unit = {
        val conf:SparkConf = new SparkConf()
        conf.setMaster("local[*]").setAppName("UDFApp")
    
        val sparkSession:SparkSession = SparkSession.builder()
          .config(conf)
          .config("spark.sql.legacy.charVarcharAsString",true)
          .getOrCreate()
    
        import sparkSession.implicits._
    
        val inpath = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
        val df = sparkSession.read
          .format("csv")
          .schema("id VARCHAR(32),timestamp BIGINT,temperature DECIMAL(5,2)")
          .load(inpath)
    
        df.show(false)
    
        //方式一,使用匿名函数实现
        val myconcat = sparkSession.udf.register("myconcat",(data1:String,data2:Long)=>{
          data1.concat(data2.toString)
        })
        df.select(myconcat($"id",$"timestamp") as "newid",$"temperature").show()
    
        //方式二、使用udf函数实现
        val myconcat2 = udf[String,String,Long]((data1:String,data2:Long)=>{
          data1.concat(data2.toString)
        })
    
        sparkSession.udf.register("myconcat2",myconcat2)
        df.select(myconcat2($"id",$"timestamp") as "newid",$"temperature").show()
    
        //方式三、继承类的方式
        val myconcat3 = sparkSession.udf.register("myconcat3",new MyConcat)
    
        df.select(myconcat3($"id",$"timestamp") as "newid",$"temperature").show()
    
        sparkSession.close()
      }
    
      class MyConcat extends Function2[String,Long,String] with Serializable{
        override def apply(v1: String, v2: Long): String = {
          v1.concat(v2.toString)
        }
      }
    }
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55

    Tips:我们在指定schema中使用VARCHAR类型时会报错,因为spark默认是没有开启VARCHAR或者CHAR类型支持的,需要设置一下参数spark.sql.legacy.charVarcharAsString为true

    val sparkSession:SparkSession = SparkSession.builder()
      .config(conf)
      .config("spark.sql.legacy.charVarcharAsString",true)
      .getOrCreate()
    
    • 1
    • 2
    • 3
    • 4

    UDAF 用户自定义聚合函数(多对一)

    说明

    UDAF是用户自定义聚合函数,一次输入多行做聚合运算输出一个聚合值作为输出结果。
    image

    使用

    UDF的使用有两种方式,一种方式是在SQL中使用,另一种方式是在DSL方式使用

    • 使用SQL的方式
    val myavg = udaf(new MyAggregator)
    sparkSession.udf.register("myavg",myavg)
    sparkSession.sql("select id,myavg(temperature) from sensor group by id ").show()
    
    • 1
    • 2
    • 3
    • 在DSL中使用
    val myavg = udaf(new MyAggregator)
    sparkSession.udf.register("myavg",myavg)
    
    ds.groupBy("id").agg(myavg($"temperature")).show()
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    实现方式

    • 方式一:实现自定义Aggregator的方式,Spark3.0中是官方推荐的实现方式

    自定义Aggregator的方式需要继承实现Aggregator[IN,BUF,OUT]类。

    * IN 输入到聚合运算中的数据类型
    * BUF 聚合缓冲区的数据类型
    * OUT 结果输出数据类型
    
    • 1
    • 2
    • 3

    弱类型的UDAF的实现

    case class AggBuffer(var count:Long,var sum:Double)
    
    class MyAggregator extends Aggregator[Double,AggBuffer,Double]{
    
      //初始化buffer
      override def zero: AggBuffer = {
        AggBuffer(0,0)
      }
    
      //reduce操作
      override def reduce(b: AggBuffer, a: Double): AggBuffer = {
        b.count+=1
        b.sum+=a
        b
      }
    
      //合并多个buffer
      override def merge(b1: AggBuffer, b2: AggBuffer): AggBuffer = {
        b1.sum+= b2.sum
        b1.count += b2.count
        b1
      }
    
      //取结果输出
      override def finish(reduction: AggBuffer): Double = {
        reduction.sum/reduction.count
      }
    
      //buffer的序列化器(对于自定义的case类是固定写法,Encoders.product)
      override def bufferEncoder: Encoder[AggBuffer] = Encoders.product
    
      //输出的序列化器(对于自定义的case类是固定写法,Encoders.product)
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    package com.hjt.yxh.hw.sparksql
    
    import com.hjt.yxh.hw.transmate.SensorReading
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.{avg, udaf}
    
    
    case class AggBuffer(var count:Long,var sum:Double)
    
    class MyAggregator extends Aggregator[Double,AggBuffer,Double]{
    
      //初始化buffer
      override def zero: AggBuffer = {
        AggBuffer(0,0)
      }
    
      //reduce操作
      override def reduce(b: AggBuffer, a: Double): AggBuffer = {
        b.count+=1
        b.sum+=a
        b
      }
    
      //合并多个buffer
      override def merge(b1: AggBuffer, b2: AggBuffer): AggBuffer = {
        b1.sum+= b2.sum
        b1.count += b2.count
        b1
      }
    
      //取结果输出
      override def finish(reduction: AggBuffer): Double = {
        reduction.sum/reduction.count
      }
    
      //buffer的序列化器(对于自定义的case类是固定写法,Encoders.product)
      override def bufferEncoder: Encoder[AggBuffer] = Encoders.product
    
      //输出的序列化器(对于自定义的case类是固定写法,Encoders.product)
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
    
    object UDAFApp {
    
      def main(args: Array[String]): Unit = {
        val conf:SparkConf = new SparkConf()
        conf.setMaster("local[*]")
          .setAppName("UDAFApp")
    
        val sparkSession:SparkSession = SparkSession.builder()
          .config(conf)
          .config("spark.sql.legacy.charVarcharAsString",true)
          .getOrCreate()
        import sparkSession.implicits._
    
        val inpath = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
    
        val df = sparkSession.read.format("CSV")
          .schema("id VARCHAR(32), timestamp BIGINT,temperature Double")
          .load(inpath)
    
        val ds:Dataset[SensorReading] = df.as[SensorReading]
        ds.createOrReplaceTempView("sensor")
    
        ds.show()
        val myavg = udaf(new MyAggregator)
        sparkSession.udf.register("myavg",myavg)
    
        ds.groupBy("id").agg(myavg($"temperature")).show()
    
        sparkSession.stop()
    
      }
    }
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 强类型的UDAF实现示例
    package com.hjt.yxh.hw.sparksql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql._
    import org.apache.spark.sql.expressions.Aggregator
    
    
    case class AggBuffer(var count:Long,var sum:Double)
    case class SensorReading(id:String,timestamp:BigInt,temperature:Double)
    
    class MyAggregator extends Aggregator[SensorReading,AggBuffer,Double]{
    
      //初始化buffer
      override def zero: AggBuffer = {
        AggBuffer(0,0)
      }
    
      //reduce操作
      override def reduce(b: AggBuffer, a: SensorReading): AggBuffer = {
        b.count+=1
        b.sum+=a.temperature
        b
      }
    
      //合并多个buffer
      override def merge(b1: AggBuffer, b2: AggBuffer): AggBuffer = {
        b1.sum+= b2.sum
        b1.count += b2.count
        b1
      }
    
      //取结果输出
      override def finish(reduction: AggBuffer): Double = {
        reduction.sum/reduction.count
      }
    
      //buffer的序列化器(对于自定义的case类是固定写法,Encoders.product)
      override def bufferEncoder: Encoder[AggBuffer] = Encoders.product
    
      //输出的序列化器(对于自定义的case类是固定写法,Encoders.product)
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
    
    object UDAFApp {
    
      def main(args: Array[String]): Unit = {
        val conf:SparkConf = new SparkConf()
        conf.setMaster("local[*]")
          .setAppName("UDAFApp")
    
        val sparkSession:SparkSession = SparkSession.builder()
          .config(conf)
          .config("spark.sql.legacy.charVarcharAsString",true)
          .getOrCreate()
        import sparkSession.implicits._
    
        val inpath = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
    
        val df = sparkSession.read.format("CSV")
          .schema("id VARCHAR(32), timestamp BIGINT,temperature Double")
          .load(inpath)
    
        val ds:Dataset[SensorReading] = df.as[SensorReading]
        ds.createOrReplaceTempView("sensor")
    
        val myavg = new MyAggregator().toColumn.name("avg_temperature")
        sparkSession.udf.register("myavg",functions.udaf(new MyAggregator()))
    
        val ret = ds.select(myavg)
    
        sparkSession.stop()
    
      }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 方式二 继承UserDefinedAggregateFunction的方式(在新的版本中已移除)
    class MyAvg extends UserDefinedAggregateFunction{
    
      //定义输入的数据结构
      override def inputSchema: StructType = new StructType()
        .add("temperature","double")
    
    
      //定义缓冲区的数据结构
      override def bufferSchema: StructType = new StructType()
        .add("count",LongType)
        .add("sum",DoubleType)
    
      //定义输出的数据类型
      override def dataType: DataType = DataTypes.DoubleType
    
    
      override def deterministic: Boolean = true
    
      //初始化缓冲区
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer.update(0,0L)
        buffer.update(1,0.0)
      }
    
      //做更新缓冲区操作,相当用reduce
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        var count = buffer.getLong(0)
        var sum = buffer.getDouble(1)
        count+= 1
        sum+= input.getDouble(0)
        buffer.update(0,count)
        buffer.update(1,sum)
      }
      //合并两个缓存去
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1.update(0,buffer1.getLong(0)+buffer2.getLong(0))
        buffer1.update(1,buffer1.getDouble(1)+buffer2.getDouble(1))
      }
    
      //计算结果,并输出
      override def evaluate(buffer: Row): Any = {
        buffer.getDouble(1)/buffer.getLong(0)
      }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44

    UDTF 用户自定义表函数(一对多)

    说明:

    SparkSQL本身其实是没有提供UDTF函数功能的,需要启用Hive支持的方式才能使用。

    实现

    class MySplit extends GenericUDTF {
    
      override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
        //判断传入的参数是否只有一个
        if (argOIs.length != 1) {
          throw new UDFArgumentException("有且只能有一个参数")
        }
        //判断参数类型
        if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
          throw new UDFArgumentException("参数类型不匹配")
        }
        val fieldNames = new util.ArrayList[String]
        val fieldOIs = new util.ArrayList[ObjectInspector]
        fieldNames.add("type")
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
        ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
      }
        
      override def process(objects: Array[AnyRef]): Unit = {
        val data_list:Array[String] = objects(0).toString.split("_")
        //遍历集合
        for(item <- data_list){
          val temp = new Array[String](1)
          temp(0)=item
          forward(temp)
        }
      }
    
      override def close(): Unit = {}
    }
    
    
    object UDTFApp {
    
      def main(args: Array[String]): Unit = {
        val conf:SparkConf = new SparkConf()
        conf.setMaster("local[*]").setAppName("UDTFTest")
    
        val sparkSession:SparkSession = SparkSession.builder().enableHiveSupport().config(conf).getOrCreate()
        import sparkSession.implicits._
    
        val inpath:String = "D:\\javaworkspace\\BigData\\Spark\\SparkApp\\src\\main\\resources\\sensor.txt"
        val df1 = sparkSession.read.format("csv")
          .schema("id String,timestamp Bigint,temperature double")
          .load(inpath)
        
        //注册表
        df1.createOrReplaceTempView("sensor")
        
        //创建函数,使用类名的方式
        sparkSession.sql("create TEMPORARY function mySplit as 'com.hjt.yxh.hw.sql.MySplit'")
    
        sparkSession.sql("select mySplit(id) from sensor").show()
    
        sparkSession.stop()
      }
    
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
  • 相关阅读:
    【Node】cookie、sessionStorage、localStorage 与 身份认证
    笔试强训2
    docker 简洁版安装kafka做测试
    c语言-实用调试技巧
    电机学 基础概念 野火电机第四章 电机分类
    Unix运维_Tcsh脚本_编译安装OpenSSL-1.1.1g
    编译原理—词法分析、构建DFA、上下文无关文法、LL(1)分析、提取正规式
    POSIX 真的不适合对象存储吗?
    21天经典算法之直接选择排序
    程序员装逼指南(2022年版本)
  • 原文地址:https://blog.csdn.net/wangzhongyudie/article/details/126107509