1. 背景
对于后台开发新的需求时,一般会先进行各种表的设计,写各个表的建表语句
然后根据建立的表,写对应的model代码、基础的增删改查代码(基础的增删改查服务可以划入DAO(Data Access Object)层)。
model代码都有一些固定的格式,可以通过解析SQL建表语句,来自动生成model代码,
对于不同的表,基础的增删改查代码大概率只是换了个表名或者数据库,因此也可以自动生成。
通过自动生成代码,减少重复工作,提示开发效率。
2. 整体介绍
目录结构如下,具体代码建Github sql2code
. ├── README.md ├── code2file │ └── code2file.go ├── go.mod ├── go.sum ├── main.go ├── sql2code_tpl │ ├── sql2code_tpl.go │ ├── sql2dao_tpl.txt │ └── sql2model_tpl.txt ├── sql2dao │ ├── sql2dao.go │ └── sql2dao_test.go ├── sql2model │ ├── sql2model.go │ ├── sql2model_test.go │ └── tidb_types.go ├── test │ └── t_student.sql └── util └── util_strings └── util_strings.go 7 directories, 15 files
sql2code_tpl: 主要是model和dao的模版代码
sql2model:MySQL建表语句到Model代码的主要处理流程
sql2dao:MySQL建表语句到Dao代码到主要处理流程
有误go get TiDB的types文件夹会出现各种冲突的错误,因此只拷贝的需要的部分代码到sql2model/tidb_types.go
3. sql到model
3.1 基本思路
使用SQL解析器获得表名及每一列信息,使用模版生成model代码。
SQL解析器使用的是TiDB parser
3.2 model模版
package {{.PackageName}} // {{.Comment}} type {{.ModelName}} struct { {{- range .Rows}} {{.Name}} {{.GoType}} {{.Tags}} // {{.Comment}} {{- end}} } func ({{.ModelName}}) TableName() string { return "{{.OriginTblName}}" }
3.3 部分实现代码
使用SQL解析器解析建表语句,获得表名,及每一列的列名,注释等信息,主要代码如下
func SQLParse(sql, tablePrefix string) (*ModelTable, error) { cts, err := parseCreateTableStmt(sql) if err != nil { log.Printf("parseCreateTableStmt fail,err:%v", err) return nil, err } mt := &ModelTable{} primaryKey := "" // table name tblName := TableNamePrefixCut(cts.Table.Name.L, tablePrefix) mt.ModelName = TableName2ModelName(tblName) mt.OriginTblName = cts.Table.Name.L // primary for _, ctt := range cts.Constraints { // only contain one primary key if ctt.Tp == ast.ConstraintPrimaryKey { if len(ctt.Keys) >= 0 { primaryKey = ctt.Keys[0].Column.Name.L } break } } // comment for _, op := range cts.Options { if op.Tp == ast.TableOptionComment { mt.Comment = op.StrValue } } modelRows := make([]ModelRow, 0, len(cts.Cols)) for _, col := range cts.Cols { nameLow := col.Name.Name.L modelRow := ModelRow{ Name: util_strings.ToCamel(col.Name.Name.L), // 需要去除下划线转驼峰 } //fmt.Printf("col: %+v %+v %v %v\n", col.Name, col.Tp, HasUnsignedFlag(col.Tp.GetFlag()), col.Tp.GetType()) modelRow.GoType = sqlType2GoType(col.Tp.GetType(), col.Tp.GetFlag()) for _, op := range col.Options { if op.Tp == ast.ColumnOptionComment { exprVal, ok := op.Expr.(*test_driver.ValueExpr) if !ok { fmt.Println("op.Expr.(*test_driver.ValueExpr) fail.") continue } modelRow.Comment = exprVal.Datum.GetString() break } } if primaryKey == col.Name.Name.L { modelRow.Tags = fmt.Sprintf("`gorm:\"column:%v; primary_key\" json:\"%v\"`", nameLow, nameLow) } else { modelRow.Tags = fmt.Sprintf("`gorm:\"column:%v;\" json:\"%v\"`", nameLow, nameLow) } //fmt.Println(modelRow.Tags) modelRows = append(modelRows, modelRow) } mt.Rows = modelRows return mt, nil }
3. sql到dao
2.1 基本思路
- 获得MySQL建表语句的表名信息
- 使用模版生成CRUD代码
2.2 模版代码
查看代码
package {{.PackageName}} {{template "addTemplate" .}} {{template "deleteTemplate" .}} {{template "updateTemplate" .}} {{template "getMultiTemplate" .}} {{template "getCountTemplate" .}} {{template "getOneTemplate" .}} {{define "addTemplate"}} func Add{{.ModelName}}(ctx context.Context,obj *{{.ModelPackage}}.{{.ModelName}}, whereMap map[string]interface{}) (error, int64) { if whereMap != nil { err, existObj := GetOne{{.ModelName}}(ctx, whereMap) if err != nil { log.Printf("[Add{{.ModelName}}]GetOne{{.ModelName}} fail, err:%v, obj:%v", err, obj) return err, int64(0) } if existObj != nil && existObj.AddTime > int64(0) { logs.CtxInfo(ctx, "[Add{{.ModelName}}] {{.ModelName}} exist, existsObj:%v", existObj) return nil, existObj.ID } } if obj.AddTime <= 0 { obj.AddTime = util_datetime.CurrentMS() } if obj.UpdateTime <= 0 { obj.UpdateTime = util_datetime.CurrentMS() } res := {{.DBConect}}.Create(obj) if res.Error != nil { log.Printf("[Add{{.ModelName}}]Add{{.ModelName}} fail, err:%v, obj:%v", res.Error, obj) return res.Error, int64(0) } return res.Error, obj.ID } {{end}} {{define "deleteTemplate"}} func Delete{{.ModelName}}(ctx context.Context,whereMap map[string]interface{}) (error, int64) { query := db.WhereQuery({{.DBConect}}, whereMap) res := query.Delete(&{{.ModelPackage}}.{{.ModelName}}{}) if res.Error != nil { log.Printf("Delete{{.ModelName}} failed, err:%v, whereMap:%v", res.Error, whereMap) return res.Error, int64(0) } rowsAffected := res.RowsAffected return nil, rowsAffected } {{end}} {{define "updateTemplate"}} func Update{{.ModelName}}(ctx context.Context, whereMap map[string]interface{}, setMap map[string]interface{}) (error, int64) { obj := &{{.ModelPackage}}.{{.ModelName}}{} query := {{.DBConect}}.Model(obj) query = db.WhereQuery(query, whereMap) if updateTime, ok := setMap["update_time"]; !ok || updateTime.(int64) <= 0 { setMap["update_time"] = util_datetime.CurrentMS() } res := query.Updates(setMap) if res.Error != nil { log.Printf("[Update{{.ModelName}}]Update{{.ModelName}} fail, err:%v, whereMap:%v, setMap:%v", res.Error, whereMap, setMap) return res.Error, int64(0) } rowsAffected := res.RowsAffected return nil, rowsAffected } {{end}} {{define "getMultiTemplate"}} func GetMulti{{.ModelName}}s(ctx context.Context,whereMap map[string]interface{}, offset, limit int64, orderBy, groupby, fields string) (error, []*{{.ModelPackage}}.{{.ModelName}}) { objs := []*{{.ModelPackage}}.{{.ModelName}}{} query := db.WhereQuery({{.DBConect}}, whereMap) query = db.OrderByQuery(query, orderBy) query = db.FieldsQuery(query, fields) query = db.GroupByQuery(query, groupby) query = db.LimitQuery(query, offset, limit) res := query.Find(&objs) if res.Error != nil { if res.Error.Error() == "record not found" { return nil, nil } log.Printf("[GetMulti{{.ModelName}}s]GetMulti{{.ModelName}}s fail, err:%v", res.Error) } return res.Error, objs } {{end}} {{define "getCountTemplate"}} func GetMulti{{.ModelName}}sCount(ctx context.Context,whereMap map[string]interface{}) (error, int64) { cnt := int64(0) query := db.WhereQuery({{.DBConect}}, whereMap) res := query.Model(&{{.ModelPackage}}.{{.ModelName}}{}).Count(&cnt) if res.Error != nil { if res.Error.Error() == "record not found" { return nil, cnt } log.Printf("GetMulti{{.ModelName}}sCount fail, err:%v", res.Error) } return res.Error, cnt } {{end}} {{define "getOneTemplate"}} func GetOne{{.ModelName}}(ctx context.Context,whereMap map[string]interface{},fields string)(error, *{{.ModelPackage}}.{{.ModelName}}) { err, objs := GetMulti{{.ModelName}}s(ctx, whereMap, 0, 1, "", "", fields) if err != nil { return err, nil } if len(objs) >= 1 { return nil, objs[0] } return nil, nil } {{end}}
2.2 部分实现代码
主要是根据表名获取对应的model名称、包名等,再利用模版生成代码。
func SQL2Dao(sql string, tablePrefix, packagePrefix, dbCon string) (string, error) { tblName, err := sql2model.TableNameGetFromSQL(sql, tablePrefix) if err != nil { return "", err } modelPackage := sql2model.ModelPackageGet(tblName, tablePrefix, packagePrefix) packageName := DaoPackageNameGet(tblName, tablePrefix, packagePrefix) df := DaoFile{ PackageName: packageName, ModelPackage: modelPackage, ModelName: sql2model.TableName2ModelName(tblName), DBConect: dbCon, } return daoFileGen(df) }
4. 使用方式
Usage of this program: -dbcon string db connect name -if string File path of the SQL statement that creates the table -op int 1:gen model code 2:gen dao code 3:both (default 1) -pp string package prefix add for go file -sql string SQL statement to create table -tp string table prefix of table name to cut
5. 使用实例
$ go run main.go -if=./test/t_student.sql -dbcon=UserDB -tp="t_" -pp=user -op=3 model code have been write to ./output/user_student.go model code have been write to ./output/user_student_service.go
6. 完整代码
5. 参考