• Spark【Spark SQL(四)UDF函数和UDAF函数】


    UDF 函数

            UDF 是我们用户可以自定义的函数,我们通过SparkSession对象来调用 udf 的 register(name:String,func(A1,A2,A3...)) 方法来注册一个我们自定义的函数。其中,name 是我们自定义的函数名称,func 是我们自定义的函数,它可以有很多个参数。

            通过 UDF 函数,我们可以针对某一列数据或者某单元格数据进行针对的处理。

    案例 1

    定义一个函数,给 Andy 的 name 字段的值前 + "Name: "。

    1. def main(args: Array[String]): Unit = {
    2. val conf = new SparkConf()
    3. conf.setAppName("spark sql udf")
    4. .setMaster("local[*]")
    5. val spark = SparkSession.builder().config(conf).getOrCreate()
    6. import spark.implicits._
    7. val df = spark.read.json("data/sql/people.json")
    8. df.createOrReplaceTempView("people")
    9. spark.udf.register("prefixName",(name:String)=>{
    10. if (name.equals("Andy"))
    11. "Name: " + name
    12. else
    13. name
    14. })
    15. spark.sql("select prefixName(name) as name,age,sex from people").show()
    16. spark.stop()
    17. }

            这里我们定义了一个自定义的 UDF 函数:prefixName,它会判断name字段的值是否为 "Andy",如果是,就会在她的值前+"Name: "。

    运行结果:

    1. +----------+---+---+
    2. | name|age|sex|
    3. +----------+---+---+
    4. | Michael| 30| 男|
    5. |Name: Andy| 19| 女|
    6. | Justin| 19| 男|
    7. |Bernadette| 20| 女|
    8. | Gretchen| 23| 女|
    9. | David| 27| 男|
    10. | Joseph| 33| 女|
    11. | Trish| 27| 女|
    12. | Alex| 33| 女|
    13. | Ben| 25| 男|
    14. +----------+---+---+

    UDAF 函数

            强类型的DataSet和弱类型的DataFrame都提供了相关聚合函数,如count、countDistinct、avg、max、min。

            UDAF 也就是我们用户的自定义聚合函数。聚合函数就比如 avg、sum这种函数,需要先把所有数据放到一起(缓冲区),再进行统一处理的一个函数。

            实现 UDAF 函数需要有我们自定义的聚合函数的类(主要任务就是计算),我们可以继承 UserDefinedAggregateFunction,并实现里面的八种方法,来实现弱类型的聚合函数。(Spark3.0之后就不推荐使用了,更加推荐强类型的聚合函数)

            我们可以继承Aggregator来实现强类型的聚合函数。

    案例1 - 平均年龄

    case 类可以直接构建对象,不需要new,因为样例类可以自动生成它的伴生对象和apply方法。

    弱类型实现

    1. import org.apache.spark.SparkConf
    2. import org.apache.spark.sql.{Row, SparkSession}
    3. import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    4. import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructField, StructType}
    5. /**
    6. * 弱类型
    7. */
    8. object UDAFTest01 {
    9. def main(args: Array[String]): Unit = {
    10. val conf = new SparkConf()
    11. conf.setAppName("spark sql udaf")
    12. .setMaster("local[*]")
    13. val spark = SparkSession.builder().config(conf).getOrCreate()
    14. import spark.implicits._
    15. val df = spark.read.json("data/sql/people.json")
    16. df.createOrReplaceTempView("people")
    17. spark.udf.register("avgAge",new MyAvgUDAF())
    18. spark.sql("select avgAge(age) from people").show()
    19. spark.stop()
    20. }
    21. }
    22. class MyAvgUDAF extends UserDefinedAggregateFunction{
    23. // 输入数据的结构 IN
    24. override def inputSchema: StructType = {
    25. StructType(
    26. Array(StructField("age",LongType))
    27. )}
    28. // 缓冲区数据的结构 BUFFER
    29. override def bufferSchema: StructType = {
    30. StructType(
    31. Array(
    32. StructField("total",LongType),
    33. StructField("count",LongType)
    34. )
    35. )}
    36. // 函数计算结果的数据类型 OUT
    37. override def dataType: DataType = LongType
    38. // 函数的稳定性 (传入相同的参数结果是否相同)
    39. override def deterministic: Boolean = true
    40. // 缓冲区初始化
    41. override def initialize(buffer: MutableAggregationBuffer): Unit = {
    42. //这两种写法都一样
    43. // buffer(0) = 0L
    44. // buffer(1) = 0L
    45. //第二种方法
    46. buffer.update(0,0L) //total 给缓冲区的第0个数据结构-total-初始化赋值0L
    47. buffer.update(1,0L) //count 给缓冲区的第1个数据结构-count-初始化赋值0L
    48. }
    49. // 数据过来之后 如何更新缓冲区
    50. override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    51. // 第一个参数代表缓冲区的第i个数据结构 0代表total 1代表count
    52. // 第二个参数是对第一个参数的数据结构进行重新赋值
    53. // buffer.getLong(0)是取出缓冲区第0个值-也就是total的值,给它+上输入的值中的第0个值(因为我们输入结构只有一个就是age:Long)
    54. buffer.update(0,buffer.getLong(0)+input.getLong(0))
    55. buffer.update(1,buffer.getLong(1)+1) //count 每次数据过来+1
    56. }
    57. // 多个缓冲区数据合并
    58. override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    59. buffer1.update(0,buffer1.getLong(0)+buffer2.getLong(0))
    60. buffer1.update(1,buffer1.getLong(1)+buffer2.getLong(1))
    61. }
    62. // 计算结果操作
    63. override def evaluate(buffer: Row): Any = {
    64. buffer.getLong(0)/buffer.getLong(1)
    65. }
    66. }

    运行结果:

    1. +-----------+
    2. |avgage(age)|
    3. +-----------+
    4. | 25|
    5. +-----------+

     

    强类型实现

    1. import org.apache.spark.SparkConf
    2. import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession, functions}
    3. import org.apache.spark.sql.expressions.Aggregator
    4. /**
    5. * 强类型
    6. */
    7. object UDAFTest02 {
    8. def main(args: Array[String]): Unit = {
    9. val conf = new SparkConf()
    10. conf.setAppName("spark sql udaf")
    11. .setMaster("local[*]")
    12. val spark = SparkSession.builder().config(conf).getOrCreate()
    13. import spark.implicits._
    14. val df = spark.read.json("data/sql/people.json")
    15. df.createOrReplaceTempView("people")
    16. spark.udf.register("avgAge",functions.udaf(new MyAvg_UDAF()))
    17. spark.sql("select avgAge(age) from people").show()
    18. spark.stop()
    19. }
    20. }
    21. /**
    22. * 自定义聚合函数类:
    23. * 1.继承org.apache.spark.sql.expressions.Aggregator,定义泛型:
    24. * IN : 输入数据类型 Long
    25. * BUF : 缓冲区数据类型
    26. * OUT : 输出数据类型 Long
    27. * 2.重写方法
    28. */
    29. //样例类中的参数默认是 val 所以这里必须指定为var
    30. case class Buff(var total: Long,var count: Long)
    31. class MyAvg_UDAF extends Aggregator[Long,Buff,Long]{
    32. // zero: Buff zero代表这个方法是用来初始值(0值)
    33. // Buff是我们的case类 也就是说明这里是用来给 缓冲区进行初始化
    34. override def zero: Buff = {
    35. Buff(0L,0L)
    36. }
    37. // 根据输入数据更新缓冲区 要求返回-Buff
    38. override def reduce(buff: Buff, in: Long): Buff = {
    39. buff.total += in
    40. buff.count += 1
    41. buff
    42. }
    43. // 合并缓冲区 同样返回buff1
    44. override def merge(buff1: Buff, buff2: Buff): Buff = {
    45. buff1.total += buff2.total
    46. buff1.count += buff2.count
    47. buff1
    48. }
    49. // 计算结果
    50. override def finish(buff: Buff): Long = {
    51. buff.total/buff.count
    52. }
    53. // 网络传输需要序列化 缓冲区的编码操作 -编码
    54. override def bufferEncoder: Encoder[Buff] = Encoders.product
    55. // 输出的编码操作 -解码
    56. override def outputEncoder: Encoder[Long] = Encoders.scalaLong
    57. }

    运行结果:

    1. +-----------+
    2. |avgage(age)|
    3. +-----------+
    4. | 25|
    5. +-----------+

     

    早期UDAF强类型聚合函数

    SQL:结构化数据查询 & DSL:面向对象查询(有对象有方法,与类型相关,所以通过DSL语句结合起来使用)

    早期的UDAF强类型聚合函数使用DSL操作。

    定义一个case类对应数据类型,然后通过as[对象]方法将DataFrame转为DataSet类型,然后将我们的UDAF聚合类转为列对象。

    1. import org.apache.spark.SparkConf
    2. import org.apache.spark.sql.{Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn, functions}
    3. import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
    4. import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructField, StructType}
    5. /**
    6. * 早期的UDAF强类型聚合函数使用DSL操作
    7. */
    8. object UDAFTest03 {
    9. def main(args: Array[String]): Unit = {
    10. val conf = new SparkConf()
    11. conf.setAppName("spark sql udaf")
    12. .setMaster("local[*]")
    13. val spark = SparkSession.builder().config(conf).getOrCreate()
    14. import spark.implicits._
    15. val df = spark.read.json("data/sql/people.json")
    16. val ds: Dataset[User] = df.as[User]
    17. // 将UDAF强类型聚合函数转为查询的类对象
    18. val udafCol: TypedColumn[User, Long] = new OldAvg_UDAF().toColumn
    19. ds.select(udafCol).show()
    20. spark.stop()
    21. }
    22. }
    23. /**
    24. * 自定义聚合函数类:
    25. * 1.继承org.apache.spark.sql.expressions.Aggregator,定义泛型:
    26. * IN : 输入数据类型 User
    27. * BUF : 缓冲区数据类型
    28. * OUT : 输出数据类型 Long
    29. * 2.重写方法
    30. */
    31. //样例类中的参数默认是 val 所以这里必须指定为var
    32. case class User(name: String,age: Long,sex: String)
    33. case class Buff(var total: Long,var count: Long)
    34. class OldAvg_UDAF extends Aggregator[User,Buff,Long]{
    35. // zero: Buff zero代表这个方法是用来初始值(0值)
    36. // Buff是我们的case类 也就是说明这里是用来给 缓冲区进行初始化
    37. override def zero: Buff = {
    38. Buff(0L,0L)
    39. }
    40. // 根据输入数据更新缓冲区 要求返回-Buff
    41. override def reduce(buff: Buff, in: User): Buff = {
    42. buff.total += in.age
    43. buff.count += 1
    44. buff
    45. }
    46. // 合并缓冲区 同样返回buff1
    47. override def merge(buff1: Buff, buff2: Buff): Buff = {
    48. buff1.total += buff2.total
    49. buff1.count += buff2.count
    50. buff1
    51. }
    52. // 计算结果
    53. override def finish(buff: Buff): Long = {
    54. buff.total/buff.count
    55. }
    56. // 网络传输需要序列化 缓冲区的编码操作 -编码
    57. override def bufferEncoder: Encoder[Buff] = Encoders.product
    58. // 输出的编码操作 -解码
    59. override def outputEncoder: Encoder[Long] = Encoders.scalaLong
    60. }

    运行结果:

    1. +------------------------------------------+
    2. |OldAvg_UDAF(com.study.spark.core.sql.User)|
    3. +------------------------------------------+
    4. | 25|
    5. +------------------------------------------+

  • 相关阅读:
    Anaconda常用命令
    Docker版部署RocketMQ开启ACL验证
    利用ETLCloud自动化流程实现业务系统数据快速同步至数仓
    京东小程序:无代码开发实现API集成,连接电商平台、CRM和客服系统
    有关遗传算法最新发展的4篇论文推荐
    <Python>PyQt5,多窗口之间参数传递和函数调用
    计算机毕业设计 高校课程评价系统的设计与实现 Java实战项目 附源码+文档+视频讲解
    简单讲解Android Fragment(四)
    【ubuntu】搭建lamp架构
    mnist数据集
  • 原文地址:https://blog.csdn.net/m0_64261982/article/details/132864647