• MySQL建表语句生成Golang代码


    1. 背景

    对于后台开发新的需求时,一般会先进行各种表的设计,写各个表的建表语句

    然后根据建立的表,写对应的model代码、基础的增删改查代码(基础的增删改查服务可以划入DAO(Data Access Object)层)。

    model代码都有一些固定的格式,可以通过解析SQL建表语句,来自动生成model代码,

    对于不同的表,基础的增删改查代码大概率只是换了个表名或者数据库,因此也可以自动生成。

    通过自动生成代码,减少重复工作,提示开发效率。

     

    2. 整体介绍

    目录结构如下,具体代码建Github sql2code

    .
    ├── README.md
    ├── code2file
    │   └── code2file.go
    ├── go.mod
    ├── go.sum
    ├── main.go
    ├── sql2code_tpl
    │   ├── sql2code_tpl.go
    │   ├── sql2dao_tpl.txt
    │   └── sql2model_tpl.txt
    ├── sql2dao
    │   ├── sql2dao.go
    │   └── sql2dao_test.go
    ├── sql2model
    │   ├── sql2model.go
    │   ├── sql2model_test.go
    │   └── tidb_types.go
    ├── test
    │   └── t_student.sql
    └── util
    └── util_strings
    └── util_strings.go
    7 directories, 15 files

    sql2code_tpl: 主要是model和dao的模版代码

    sql2model:MySQL建表语句到Model代码的主要处理流程

    sql2dao:MySQL建表语句到Dao代码到主要处理流程

    有误go get TiDB的types文件夹会出现各种冲突的错误,因此只拷贝的需要的部分代码到sql2model/tidb_types.go

     

    3. sql到model

    3.1 基本思路

    使用SQL解析器获得表名及每一列信息,使用模版生成model代码。

    SQL解析器使用的是TiDB parser

    3.2 model模版

    package {{.PackageName}}
    // {{.Comment}}
    type {{.ModelName}} struct {
    {{- range .Rows}}
    {{.Name}} {{.GoType}} {{.Tags}} // {{.Comment}}
    {{- end}}
    }
    func ({{.ModelName}}) TableName() string {
    return "{{.OriginTblName}}"
    }

    3.3 部分实现代码

    使用SQL解析器解析建表语句,获得表名,及每一列的列名,注释等信息,主要代码如下

    func SQLParse(sql, tablePrefix string) (*ModelTable, error) {
    cts, err := parseCreateTableStmt(sql)
    if err != nil {
    log.Printf("parseCreateTableStmt fail,err:%v", err)
    return nil, err
    }
    mt := &ModelTable{}
    primaryKey := ""
    // table name
    tblName := TableNamePrefixCut(cts.Table.Name.L, tablePrefix)
    mt.ModelName = TableName2ModelName(tblName)
    mt.OriginTblName = cts.Table.Name.L
    // primary
    for _, ctt := range cts.Constraints {
    // only contain one primary key
    if ctt.Tp == ast.ConstraintPrimaryKey {
    if len(ctt.Keys) >= 0 {
    primaryKey = ctt.Keys[0].Column.Name.L
    }
    break
    }
    }
    // comment
    for _, op := range cts.Options {
    if op.Tp == ast.TableOptionComment {
    mt.Comment = op.StrValue
    }
    }
    modelRows := make([]ModelRow, 0, len(cts.Cols))
    for _, col := range cts.Cols {
    nameLow := col.Name.Name.L
    modelRow := ModelRow{
    Name: util_strings.ToCamel(col.Name.Name.L), // 需要去除下划线转驼峰
    }
    //fmt.Printf("col: %+v %+v %v %v\n", col.Name, col.Tp, HasUnsignedFlag(col.Tp.GetFlag()), col.Tp.GetType())
    modelRow.GoType = sqlType2GoType(col.Tp.GetType(), col.Tp.GetFlag())
    for _, op := range col.Options {
    if op.Tp == ast.ColumnOptionComment {
    exprVal, ok := op.Expr.(*test_driver.ValueExpr)
    if !ok {
    fmt.Println("op.Expr.(*test_driver.ValueExpr) fail.")
    continue
    }
    modelRow.Comment = exprVal.Datum.GetString()
    break
    }
    }
    if primaryKey == col.Name.Name.L {
    modelRow.Tags = fmt.Sprintf("`gorm:\"column:%v; primary_key\" json:\"%v\"`", nameLow, nameLow)
    } else {
    modelRow.Tags = fmt.Sprintf("`gorm:\"column:%v;\" json:\"%v\"`", nameLow, nameLow)
    }
    //fmt.Println(modelRow.Tags)
    modelRows = append(modelRows, modelRow)
    }
    mt.Rows = modelRows
    return mt, nil
    }

    3. sql到dao

    2.1 基本思路

    • 获得MySQL建表语句的表名信息
    • 使用模版生成CRUD代码

    2.2 模版代码

    查看代码
     package {{.PackageName}}
    {{template "addTemplate" .}}
    {{template "deleteTemplate" .}}
    {{template "updateTemplate" .}}
    {{template "getMultiTemplate" .}}
    {{template "getCountTemplate" .}}
    {{template "getOneTemplate" .}}
    {{define "addTemplate"}}
    func Add{{.ModelName}}(ctx context.Context,obj *{{.ModelPackage}}.{{.ModelName}}, whereMap map[string]interface{}) (error, int64) {
    if whereMap != nil {
    err, existObj := GetOne{{.ModelName}}(ctx, whereMap)
    if err != nil {
    log.Printf("[Add{{.ModelName}}]GetOne{{.ModelName}} fail, err:%v, obj:%v", err, obj)
    return err, int64(0)
    }
    if existObj != nil && existObj.AddTime > int64(0) {
    logs.CtxInfo(ctx, "[Add{{.ModelName}}] {{.ModelName}} exist, existsObj:%v", existObj)
    return nil, existObj.ID
    }
    }
    if obj.AddTime <= 0 {
    obj.AddTime = util_datetime.CurrentMS()
    }
    if obj.UpdateTime <= 0 {
    obj.UpdateTime = util_datetime.CurrentMS()
    }
    res := {{.DBConect}}.Create(obj)
    if res.Error != nil {
    log.Printf("[Add{{.ModelName}}]Add{{.ModelName}} fail, err:%v, obj:%v", res.Error, obj)
    return res.Error, int64(0)
    }
    return res.Error, obj.ID
    }
    {{end}}
    {{define "deleteTemplate"}}
    func Delete{{.ModelName}}(ctx context.Context,whereMap map[string]interface{}) (error, int64) {
    query := db.WhereQuery({{.DBConect}}, whereMap)
    res := query.Delete(&{{.ModelPackage}}.{{.ModelName}}{})
    if res.Error != nil {
    log.Printf("Delete{{.ModelName}} failed, err:%v, whereMap:%v", res.Error, whereMap)
    return res.Error, int64(0)
    }
    rowsAffected := res.RowsAffected
    return nil, rowsAffected
    }
    {{end}}
    {{define "updateTemplate"}}
    func Update{{.ModelName}}(ctx context.Context, whereMap map[string]interface{}, setMap map[string]interface{}) (error, int64) {
    obj := &{{.ModelPackage}}.{{.ModelName}}{}
    query := {{.DBConect}}.Model(obj)
    query = db.WhereQuery(query, whereMap)
    if updateTime, ok := setMap["update_time"]; !ok || updateTime.(int64) <= 0 {
    setMap["update_time"] = util_datetime.CurrentMS()
    }
    res := query.Updates(setMap)
    if res.Error != nil {
    log.Printf("[Update{{.ModelName}}]Update{{.ModelName}} fail, err:%v, whereMap:%v, setMap:%v", res.Error, whereMap, setMap)
    return res.Error, int64(0)
    }
    rowsAffected := res.RowsAffected
    return nil, rowsAffected
    }
    {{end}}
    {{define "getMultiTemplate"}}
    func GetMulti{{.ModelName}}s(ctx context.Context,whereMap map[string]interface{}, offset, limit int64, orderBy, groupby, fields string) (error, []*{{.ModelPackage}}.{{.ModelName}}) {
    objs := []*{{.ModelPackage}}.{{.ModelName}}{}
    query := db.WhereQuery({{.DBConect}}, whereMap)
    query = db.OrderByQuery(query, orderBy)
    query = db.FieldsQuery(query, fields)
    query = db.GroupByQuery(query, groupby)
    query = db.LimitQuery(query, offset, limit)
    res := query.Find(&objs)
    if res.Error != nil {
    if res.Error.Error() == "record not found" {
    return nil, nil
    }
    log.Printf("[GetMulti{{.ModelName}}s]GetMulti{{.ModelName}}s fail, err:%v", res.Error)
    }
    return res.Error, objs
    }
    {{end}}
    {{define "getCountTemplate"}}
    func GetMulti{{.ModelName}}sCount(ctx context.Context,whereMap map[string]interface{}) (error, int64) {
    cnt := int64(0)
    query := db.WhereQuery({{.DBConect}}, whereMap)
    res := query.Model(&{{.ModelPackage}}.{{.ModelName}}{}).Count(&cnt)
    if res.Error != nil {
    if res.Error.Error() == "record not found" {
    return nil, cnt
    }
    log.Printf("GetMulti{{.ModelName}}sCount fail, err:%v", res.Error)
    }
    return res.Error, cnt
    }
    {{end}}
    {{define "getOneTemplate"}}
    func GetOne{{.ModelName}}(ctx context.Context,whereMap map[string]interface{},fields string)(error, *{{.ModelPackage}}.{{.ModelName}}) {
    err, objs := GetMulti{{.ModelName}}s(ctx, whereMap, 0, 1, "", "", fields)
    if err != nil {
    return err, nil
    }
    if len(objs) >= 1 {
    return nil, objs[0]
    }
    return nil, nil
    }
    {{end}}

    2.2 部分实现代码

    主要是根据表名获取对应的model名称、包名等,再利用模版生成代码。

    func SQL2Dao(sql string, tablePrefix, packagePrefix, dbCon string) (string, error) {
    tblName, err := sql2model.TableNameGetFromSQL(sql, tablePrefix)
    if err != nil {
    return "", err
    }
    modelPackage := sql2model.ModelPackageGet(tblName, tablePrefix, packagePrefix)
    packageName := DaoPackageNameGet(tblName, tablePrefix, packagePrefix)
    df := DaoFile{
    PackageName: packageName,
    ModelPackage: modelPackage,
    ModelName: sql2model.TableName2ModelName(tblName),
    DBConect: dbCon,
    }
    return daoFileGen(df)
    }

    4. 使用方式

    Usage of this program:
    -dbcon string
    db connect name
    -if string
    File path of the SQL statement that creates the table
    -op int
    1:gen model code 2:gen dao code 3:both (default 1)
    -pp string
    package prefix add for go file
    -sql string
    SQL statement to create table
    -tp string
    table prefix of table name to cut

    5. 使用实例

    $ go run main.go -if=./test/t_student.sql -dbcon=UserDB -tp="t_" -pp=user -op=3
    model code have been write to ./output/user_student.go
    model code have been write to ./output/user_student_service.go

    6. 完整代码

    Github sql2code

    5. 参考

    1. TiDB parser quickstart

     

  • 相关阅读:
    客户心声 | 四川省人社厅杨玉成一行充分肯定桂溪街道劳动保障工作信息化建设平台
    似然和概率
    【CF1635F】 Closest Pair 题解
    AI换脸之Faceswap技术原理与实践
    Ubuntu 20.04使用源码安装nginx 1.14.0
    PAM从入门到精通(二十三)
    Linux环境python脚本后台运行
    使用Filebeat+Kafka+Logstash+Elasticsearch构建日志分析系统
    【frontend】如何横向排列?
    数据挖掘面试经总结【私人版,仅供参考】
  • 原文地址:https://www.cnblogs.com/amos01/p/16659999.html