• SparkSQL batch insert-or-update: saving data to MySQL


    When saving a DataFrame to a database, Spark SQL only offers four save modes: Append, Overwrite, ErrorIfExists, and Ignore. That does not meet our project's requirement. Briefly, the requirement is this: whenever data changes in the business database, the corresponding records in the ODS layer of the data warehouse must be updated, inserted, or deleted, so the save logic has to be adapted from the source code.

    Following the Spark save() source code, the logic below reworks the write path to save data in batches: rows that already exist are updated, and rows that do not are inserted.

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.getCommonJDBCType
    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
    import java.sql.{Connection, DriverManager, PreparedStatement}
    import java.util.Properties

    object TestInsertOrUpdateMysql {
      val url: String = "jdbc:mysql://192.168.1.1:3306/test?useUnicode=true&characterEncoding=UTF-8&useSSL=false&allowMultiQueries=true&autoReconnect=true&failOverReadOnly=false"
      val driver: String = "com.mysql.jdbc.Driver"
      val user: String = "123"
      val password: String = "123"
      val sql: String = "select * from testserver"
      val table: String = "testinsertorupdate"

      // Stand-in for Spark's internal logError (the original code accidentally imported
      // one from com.sun.corba); swap in your preferred logger.
      private def logError(msg: String): Unit = System.err.println(msg)

      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("testSqlServer")
          .getOrCreate()
        // wrap the query so it can be pushed down through the dbtable option
        val dbtable = "(" + sql + ") AS Temp"
        val jdbcDF = spark.read
          .format("jdbc")
          .option("driver", driver)
          .option("user", user)
          .option("password", password)
          .option("url", url)
          .option("dbtable", dbtable)
          .load()
        jdbcDF.show()
        // plain write to the database
        //commonWrite(jdbcDF)
        // insert or update
        insertOrUpdateToMysql("id", jdbcDF, spark)
        println("====================== Program finished ======================")
      }
      // commonWrite, insertOrUpdateToMysql and the helpers shown below are further members of this object

    1. First, let's look at how a plain insert is written:

    def commonWrite(jdbcDF: DataFrame): Unit = {
      val properties = new Properties()
      properties.put("user", user)
      properties.put("password", password)
      properties.put("driver", driver)
      jdbcDF.write.mode(SaveMode.Append).jdbc(url, table, properties)
    }

    This approach is quite limited: it only supports the simple save modes (append, overwrite, and so on, e.g. SaveMode.Append).

    So what does the new approach look like? First, write out MySQL's insert-or-update syntax:

    INSERT INTO t_name ( c1, c2, c3 )
    VALUES
      ( 1, '1', '1')
    ON DUPLICATE KEY UPDATE
      c2 = '2';

    Note that the table must have a primary key (or a unique key); without one the update branch can never fire.
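    To make that concrete, here is a minimal sketch (the columns of testinsertorupdate are hypothetical, and this DDL is not part of the original post) of a target table that satisfies the requirement, created over the same connection settings defined in the object above:

    import java.sql.DriverManager

    // ON DUPLICATE KEY UPDATE only fires when the INSERT would violate a
    // PRIMARY KEY or UNIQUE constraint, so the target table needs one.
    val conn = DriverManager.getConnection(url, user, password) // url/user/password as defined in the object above
    val stmt = conn.createStatement()
    stmt.execute(
      """CREATE TABLE IF NOT EXISTS testinsertorupdate (
        |  id   INT PRIMARY KEY,
        |  name VARCHAR(64),
        |  age  INT
        |)""".stripMargin)
    stmt.close()
    conn.close()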

    2. Now let's look at the insertOrUpdate implementation:

    // Write to the database, inserting or updating rows in batches. The method is
    // adapted from the Spark.write.save() source code.
    // Rule: rows whose key is not present are inserted, rows whose key exists are updated.
    def insertOrUpdateToMysql(primaryKey: String, jdbcDF: DataFrame, spark: SparkSession): Unit = {
      val sc: SparkContext = spark.sparkContext
      spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // 1. load the JDBC driver
      Class.forName(driver)
      // 2. open the database connection
      val conn: Connection = DriverManager.getConnection(url, user, password)
      val tableSchema = jdbcDF.schema
      val columns = tableSchema.fields.map(x => x.name).mkString(",")
      val placeholders = tableSchema.fields.map(_ => "?").mkString(",")
      val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders) on duplicate key update "
      val update = tableSchema.fields.map(x =>
        x.name.toString + "=?"
      ).mkString(",")
      // ON DUPLICATE KEY UPDATE           (MySQL)
      // on conflict($primaryKey) do update set   (PostgreSQL equivalent)
      val realsql = sql.concat(update)
      conn.setAutoCommit(false)
      val dialect = JdbcDialects.get(conn.getMetaData.getURL)
      // broadcasting the PreparedStatement is what requires the Kryo serializer (see the note below)
      val broad_ps = sc.broadcast(conn.prepareStatement(realsql))
      // the statement has two groups of placeholders: the VALUES (...) list and the UPDATE list
      val numFields = tableSchema.fields.length * 2
      // use Spark's own helpers to derive the JDBC type and setter for each field
      val nullTypes = tableSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
      val setters = tableSchema.fields.map(f => makeSetter(conn, f.dataType))
      var rowCount = 0
      val batchSize = 2000
      val updateindex = numFields / 2
      try {
        jdbcDF.foreachPartition((iterator: Iterator[Row]) => {
          // iterate over the partition and submit in batches
          val ps = broad_ps.value
          try {
            while (iterator.hasNext) {
              val row = iterator.next()
              var i = 0
              while (i < numFields) {
                i < updateindex match {
                  case true =>
                    // first half of the placeholders: the VALUES (...) part
                    if (row.isNullAt(i)) {
                      ps.setNull(i + 1, nullTypes(i))
                    } else {
                      setters(i).apply(ps, row, i, 0)
                    }
                  case false =>
                    // second half: the ON DUPLICATE KEY UPDATE part, re-using the same row values
                    if (row.isNullAt(i - updateindex)) {
                      ps.setNull(i + 1, nullTypes(i - updateindex))
                    } else {
                      setters(i - updateindex).apply(ps, row, i, updateindex)
                    }
                }
                i = i + 1
              }
              ps.addBatch()
              rowCount += 1
              if (rowCount % batchSize == 0) {
                ps.executeBatch()
                rowCount = 0
              }
            }
            if (rowCount > 0) {
              ps.executeBatch()
            }
          } finally {
            ps.close()
          }
        })
        conn.commit()
      } catch {
        case e: Exception =>
          logError("Error in execution of insert. " + e.getMessage)
          conn.rollback()
          // insertError(connectionPool("OuCloud_ODS"), "insertOrUpdateToPgsql", e.getMessage)
      } finally {
        conn.close()
      }
    }
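    To make the string building concrete, here is a small sketch (not from the original post; the columns id, name, age of testinsertorupdate are hypothetical) of the statement insertOrUpdateToMysql ends up preparing:

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    // Hypothetical schema: testinsertorupdate(id INT, name VARCHAR, age INT)
    val tableSchema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType),
      StructField("age", IntegerType)))
    val columns      = tableSchema.fields.map(_.name).mkString(",")
    val placeholders = tableSchema.fields.map(_ => "?").mkString(",")
    val update       = tableSchema.fields.map(_.name + "=?").mkString(",")
    println(s"INSERT INTO testinsertorupdate ($columns) VALUES ($placeholders) on duplicate key update $update")
    // => INSERT INTO testinsertorupdate (id,name,age) VALUES (?,?,?)
    //    on duplicate key update id=?,name=?,age=?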

    A few helpers lifted from the Spark source:

    private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
      dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
        throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}"))
    }

    // Same idea as Spark's internal JDBCValueSetter, but with an extra offset argument
    // (currentpos) so one row value can be bound into both halves of the statement.
    private type JDBCValueSetter_add = (PreparedStatement, Row, Int, Int) => Unit

    private def makeSetter(conn: Connection, dataType: DataType): JDBCValueSetter_add = dataType match {
      case IntegerType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setInt(pos + 1, row.getInt(pos - currentpos))
      case LongType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setLong(pos + 1, row.getLong(pos - currentpos))
      case DoubleType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setDouble(pos + 1, row.getDouble(pos - currentpos))
      case FloatType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setFloat(pos + 1, row.getFloat(pos - currentpos))
      case ShortType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setInt(pos + 1, row.getShort(pos - currentpos))
      case ByteType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setInt(pos + 1, row.getByte(pos - currentpos))
      case BooleanType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setBoolean(pos + 1, row.getBoolean(pos - currentpos))
      case StringType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setString(pos + 1, row.getString(pos - currentpos))
      case BinaryType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos - currentpos))
      case TimestampType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - currentpos))
      case DateType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos - currentpos))
      case t: DecimalType =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          stmt.setBigDecimal(pos + 1, row.getDecimal(pos - currentpos))
      case _ =>
        (stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
          throw new IllegalArgumentException(
            s"Can't translate non-null value for field $pos")
    }
  } // closes object TestInsertOrUpdateMysql
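    To see how the extra offset plays out, here is a short trace (hypothetical two-column row, not from the original post) of the setter calls the write loop makes for a row (id = 1, name = "a"), where numFields = 4 and updateindex = 2:

    // i = 0 -> setters(0).apply(ps, row, 0, 0)           => ps.setInt(1, 1)       // VALUES ?, column id
    // i = 1 -> setters(1).apply(ps, row, 1, 0)           => ps.setString(2, "a")  // VALUES ?, column name
    // i = 2 -> setters(0).apply(ps, row, 2, updateindex) => ps.setInt(3, 1)       // update id = ?
    // i = 3 -> setters(1).apply(ps, row, 3, updateindex) => ps.setString(4, "a")  // update name = ?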

    One configuration property here is critical enough to call out; without it the job fails with Exception in thread "main" java.io.NotSerializableException: com.mysql.jdbc.JDBC42PreparedStatement:

    spark.conf.set("spark.serializer","org.apache.spark.serializer.KryoSerializer")

    Add this configuration so the third-party Kryo serializer is used; the default JavaSerializer cannot serialize the broadcast PreparedStatement.
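    As an aside (not from the original post): spark.serializer is normally fixed when the SparkContext is created, so a safer place to set it is on the SparkSession builder rather than via spark.conf.set after the session is up. A minimal sketch, assuming the same app name and master as above:

    import org.apache.spark.sql.SparkSession

    // Configure Kryo before the SparkContext exists.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("testSqlServer")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .getOrCreate()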

    The run then produces the expected result (the original post shows a screenshot here). There are a few pitfalls along the way, so feel free to get in touch if you need to discuss them.

    Appendix: the equivalent insert-or-update syntax for PostgreSQL:

    INSERT INTO test_001 ( c1, c2, c3 )
    VALUES ( ?, ?, ? )
    ON CONFLICT ( ID ) DO
    UPDATE SET c1 = ?, c2 = ?, c3 = ?;
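    For completeness, a small sketch (again using the hypothetical id, name, age columns; not from the original post) of how the generated statement changes for a PostgreSQL target; this mirrors the commented-out on conflict(...) line inside insertOrUpdateToMysql:

    // Sketch: PostgreSQL flavour of the generated upsert (hypothetical columns).
    val primaryKey   = "id"
    val columns      = "id,name,age"
    val placeholders = "?,?,?"
    val update       = "id=?,name=?,age=?"
    val pgSql = s"INSERT INTO testinsertorupdate ($columns) VALUES ($placeholders) " +
      s"on conflict ($primaryKey) do update set $update"
    // => INSERT INTO testinsertorupdate (id,name,age) VALUES (?,?,?)
    //    on conflict (id) do update set id=?,name=?,age=?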

    Reference: "MySQL 的 on duplicate key update 的使用", CSDN blog by 厄尔尼诺的夏天.

    Blogger's QQ: 907044657. Everyone is welcome to get in touch and learn together; if you find a problem, please point it out. If you repost, please credit the source. Thanks!

  • Original article: https://blog.csdn.net/Alex_81D/article/details/125893271