目录
单元测试,顾名思义对某个单元函数进行测试,被测函数本身中用到的变量、函数、资源不应被测试代码依赖,所谓 mock,就是想办法通过 “虚拟” 代码替换掉依赖的方法和资源,一般需要 mock 掉以下依赖:
变量
函数/方法
MySQL
Redis
http 调用
有时我们的代码里依赖一个全局变量,测试方法根据全局变量的不同值执行不同的逻辑,那么可以用 gostub 对变量进行打桩。
global.go:
- package main
-
- var size = 5
-
- func Size() int {
- if size > 10 {
- return 10
- }
- return size
- }
- package main
-
- import (
- "testing"
-
- "github.com/agiledragon/gomonkey/v2"
- "github.com/prashantv/gostub"
- )
-
- func TestSizeStub(t *testing.T) {
- tests := []struct {
- name string
- want int
- f func() *gostub.Stubs
- }{
- {name: "size > 10", want: 10, f: func() *gostub.Stubs {
- return gostub.Stub(&size, 11)
- }},
- {name: "size <= 10", want: 3, f: func() *gostub.Stubs {
- return gostub.Stub(&size, 3)
- }},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- stub := tt.f()
- if got := Size(); got != tt.want {
- t.Errorf("Size() = %v, want %v", got, tt.want)
- }
- stub.Reset()
- })
- }
- }
-
- func TestSizeMonkey(t *testing.T) {
- tests := []struct {
- name string
- want int
- f func() *gomonkey.Patches
- }{
- {name: "size > 10", want: 10, f: func() *gomonkey.Patches {
- return gomonkey.ApplyGlobalVar(&size, 11)
- }},
- {name: "size <= 10", want: 3, f: func() *gomonkey.Patches {
- return gomonkey.ApplyGlobalVar(&size, 3)
- }},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- stub := tt.f()
- if got := Size(); got != tt.want {
- t.Errorf("Size() = %v, want %v", got, tt.want)
- }
- stub.Reset()
- })
- }
- }
- $ go test -v -cover
- === RUN TestSize
- === RUN TestSize/size_>_10
- === RUN TestSize/size_<=_10
- --- PASS: TestSize (0.00s)
- --- PASS: TestSize/size_>_10 (0.00s)
- --- PASS: TestSize/size_<=_10 (0.00s)
- PASS
- coverage: 100.0% of statements
首先 Go 语言推荐的是面向接口编程,所以官方提供并推荐使用 gomock 对依赖的方法进行 mock,前提是依赖的方法是通过抽象接口实现的,gomock 执行过程如下:
使用mockgen为你想要mock的接口生成一个mock。
在你的测试代码中,创建一个gomock.Controller实例并把它作为参数传递给mock对象的构造函数来创建一个mock对象。
调用EXPECT()为你的mock对象设置各种期望和返回值。
调用mock控制器的Finish()以验证mock的期望行为。
gomock 常用方法:
| 类型 | 用法 | 作用 |
| 参数 | gomock.Any(v) | 匹配任何类型 |
| gomock.Eq(v) | 匹配使用反射 reflect.DeepEqual 与 v 相等的值 | |
| gomock.Not(v) | v 不是 Matcher 时,匹配使用反射 reflect.DeepEqual 与 v 不相等的值;v 是 Matcher 时,匹配和 Macher 不匹配的值(Matcher) | |
| gomock.Nil() | 匹配等于 nil 的值 | |
| 返回 | Return() | mock 方法返回值 |
| Do(func) | 传入的 func 在 mock 真正被调用时自动执行,忽略 Return,比如:对调用方法的参数进行校验 | |
| DoAndReturn(func) | 传入的 func 在 mock 真正被调用时自动执行,对应 func 返回值作为 mock 方法返回值 | |
| 调用次数 | AnyTimes(n int) | mock 方法可以被调用任意次数,一次不调用也不会失败(这里大家可以自检一下各自的单测代码,用这个方法的单测可能并没有按照预期运行) |
| Times() | mock 方法被调用次数,次数不相等运行失败 | |
| MaxTimes(n int) | mock 方法被调用次数,大于规定次数运行失败 | |
| MinTimes(n int) | mock 方法被调用次数,小于规定次数运行失败 | |
| 调用排序 | gomock.InOrder( first.EXPECT.Func().Return(), second.EXPECT.Func().Return(), thrid.EXPECT.Func().Return(), ) | 规定多个 mock 方法的调用顺序,顺序不符运行失败 |
| first := rc.EXPECT().DoFucn() second := rc.EXPECT().DoFunc().After(first) | 规定多个 mock 方法的先后依赖关系,顺序不符运行失败 |
首先通过 mockgen 生成 Redis Client 的 mock 代码:
- $ go get -u github.com/golang/mock/gomock
- $ go install github.com/golang/mock/mockgen
-
- 本地interface:
- mockgen[go run -mod=mod github.com/golang/mock/mockgen -package mock] -source ~/go/pkg/mod/github.com/opentracing/opentracing-go\@v1.2.0/tracer.go -destination ./opentracing/tracer.go Tracer
- 远端interface:
- go run -mod=mod github.com/golang/mock/mockgen -package redis -destination ./mock/redis/redis.go github.com/go-redis/redis/v8 Cmdable
redis.go:
- package main
-
- import (
- "context"
-
- "github.com/go-redis/redis/v8"
- )
-
- func handleRedis(c redis.Cmdable) (string, error) {
- return c.Get(context.Background(), "redis").Result()
- }
-
- func conn() *redis.Client {
- return redis.NewClient(&redis.Options{Addr: "127.0.0.1:6379"})
- }
- package main
-
- import (
- "testing"
-
- "github.com/go-redis/redis/v8"
- "github.com/golang/mock/gomock"
- )
-
- func Test_handleRedis(t *testing.T) {
- ctl := gomock.NewController(t)
- defer ctl.Finish()
-
- c := NewMockCmdable(ctl)
- c.EXPECT().Get(gomock.Any(), gomock.Any()).Times(1).Return(redis.NewStringResult("redis", nil))
-
- handleRedis(c)
- }
假如我们依赖的其他人写的方法,并不是通过接口实现的,无法使用 gomock 时,可以用 gomonkey 进行打桩
常用函数:
gomonkey.ApplyFunc():单个包函数打桩
gomonkey.ApplyFuncSeq():连续多个包函数打桩
func.go
- package main
-
- func A() int {
- return B()
- }
-
- func AA() int {
- return B() + B()
- }
-
- func B() int {
- return 0
- }
- package main
-
- import (
- "testing"
-
- "github.com/agiledragon/gomonkey/v2"
- "github.com/stretchr/testify/assert"
- )
-
- // TestA 函数,单次打桩
- func TestA(t *testing.T) {
- patch := gomonkey.ApplyFunc(B, func() int {
- return 1
- })
- defer patch.Reset()
-
- assert.Equal(t, 1, A())
- }
-
- // TestAA 函数,连续打桩
- func TestAA(t *testing.T) {
- patch := gomonkey.ApplyFuncSeq(B, []gomonkey.OutputCell{
- {Values: gomonkey.Params{1}},
- {Values: gomonkey.Params{2}},
- })
- defer patch.Reset()
-
- assert.Equal(t, 3, AA())
- }
常用函数:
gomonkey.ApplyMethod():单个公有成员方法打桩
patch.ApplyPrivateMethod():单个私有成员方法打桩
patch.ApplyMethodSeq():连续多个公有成员方法打桩
gomonkey.ApplyFuncSeq():连续多个私有成员方法打桩
method.go
- package main
-
- type S struct{}
-
- func (s *S) A() int {
- return s.B() + s.b()
- }
-
- func (s *S) AA() int {
- return s.B() + s.b() + s.B() + s.b()
- }
-
- func (s *S) B() int {
- return 0
- }
-
- func (s *S) b() int {
- return 0
- }
- package main
-
- import (
- "reflect"
- "testing"
-
- "github.com/agiledragon/gomonkey/v2"
- "github.com/stretchr/testify/assert"
- )
-
- // TestS_AA 成员方法单个打桩
- func TestS_A(t *testing.T) {
- s := &S{}
-
- // 公共成员方法
- patch := gomonkey.ApplyMethod(reflect.TypeOf(s), "B", func(_ *S) int {
- return 1
- })
- // 私有成员方法
- patch.ApplyPrivateMethod(reflect.TypeOf(s), "b", func(_ *S) int {
- return 2
- })
- defer patch.Reset()
-
- assert.Equal(t, 3, s.A())
- }
-
- // TestS_AA 成员方法连续打桩
- func TestS_AA(t *testing.T) {
- s := &S{}
-
- // 私有成员方法
- patch := gomonkey.ApplyFuncSeq((*S).b, []gomonkey.OutputCell{
- {Values: gomonkey.Params{1}},
- {Values: gomonkey.Params{2}},
- })
- // 公共成员方法
- patch.ApplyMethodSeq(reflect.TypeOf(s), "B", []gomonkey.OutputCell{
- {Values: gomonkey.Params{1}},
- {Values: gomonkey.Params{2}},
- })
- defer patch.Reset()
-
- assert.Equal(t, 6, s.AA())
- }
db.go
- package main
-
- import (
- "database/sql"
- "encoding/json"
- "fmt"
-
- _ "github.com/go-sql-driver/mysql"
- "github.com/jmoiron/sqlx"
- )
-
- const dsn = "root:123456@tcp(127.0.0.1:3306)/test"
-
- type Test struct {
- ID int64 `json:"id" db:"id" gorm:"column:id"`
- GoodsID int64 `json:"goodsID" db:"goods_id" gorm:"column:goods_id"`
- Name string `json:"name" db:"name" gorm:"column:name"`
- }
-
- func (Test) TableName() string {
- return "test"
- }
-
- func handle(db *sql.DB) (err error) {
- tx, err := db.Begin()
- if err != nil {
- return
- }
-
- defer func() {
- switch err {
- case nil:
- err = tx.Commit()
- default:
- tx.Rollback()
- }
- }()
-
- rows, err := tx.Query("SELECT * from test where id > ?", 0)
- if err != nil {
- panic(err)
- }
- result := []Test{}
- if err = sqlx.StructScan(rows, &result); err != nil {
- panic(err)
- }
-
- b, err := json.Marshal(result)
- if err != nil {
- panic(err)
- }
- fmt.Println("sql:", string(b))
-
- if _, err = tx.Exec("UPDATE test SET goods_id = goods_id + 1 where id = 2"); err != nil {
- return
- }
- if _, err = tx.Exec("INSERT INTO test (goods_id, name) VALUES (?, ?)", 1, "1"); err != nil {
- return
- }
- return
- }
-
- func main() {
- db, err := sql.Open("mysql", dsn)
- if err != nil {
- panic(err)
- }
- defer db.Close()
-
- if err = handle(db); err != nil {
- panic(err)
- }
- }
- package main
-
- import (
- "log"
- "os"
- "testing"
- "time"
-
- "github.com/DATA-DOG/go-sqlmock"
- _ "github.com/go-sql-driver/mysql"
- "github.com/stretchr/testify/assert"
- )
-
- func Test_handle(t *testing.T) {
- db, mock, err := sqlmock.New()
- if err != nil {
- panic(err)
- }
-
- mock.ExpectBegin()
- // (.+) 用于替代字段,可用于 select、order、group等
- mock.ExpectQuery("SELECT (.+) from test where id > ?").WillReturnRows(sqlmock.NewRows([]string{"id", "goods_id", "name"}).AddRow(1, 1, "1"))
- // sql前缀匹配
- mock.ExpectExec("UPDATE test SET goods_id").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("INSERT INTO test").WithArgs(1, "1").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- if err = handle(db); err != nil {
- panic(err)
- }
-
- if err = mock.ExpectationsWereMet(); err != nil {
- panic(err)
- }
- }
如果遇到如下错误:
- /usr/local/go16/pkg/tool/linux_amd64/link: running gcc failed: exit status 1
- /usr/bin/ld: /tmp/go-link-866330658/000020.o(.text+0x74): unresolvable H��@�>H��FH��H��H��@�~�F�H��@�~H��8�H��H��0�FH��H��(�FH��H�� �FH��H���FH��H���FH��H��F�fD relocation against symbol `stderr@@GLIBC_2.2.5'
- /usr/bin/ld: BFD version 2.20.51.0.2-5.34.el6 20100205 internal error, aborting at reloc.c line 443 in bfd_get_reloc_size
- /usr/bin/ld: Please report this bug.
- collect2: ld returned 1 exit status
- 更新 go env gcc 版本:
- go env -w CC=/opt/compiler/gcc-8.2/bin/gcc
- go env -w CXX=/opt/compiler/gcc-8.2/bin/g++
- 或
- CC=/opt/compiler/gcc-8.2/bin/gcc CXX=/opt/compiler/gcc-8.2/bin/g++ go test -c -cover
db.go
- package main
-
- import (
- "database/sql"
- "encoding/json"
- "fmt"
-
- "gorm.io/driver/mysql"
- "gorm.io/gorm"
- )
-
- const dsn = "root:123456@tcp(127.0.0.1:3306)/test"
-
- type Test struct {
- ID int64 `json:"id" db:"id" gorm:"column:id"`
- GoodsID int64 `json:"goodsID" db:"goods_id" gorm:"column:goods_id"`
- Name string `json:"name" db:"name" gorm:"column:name"`
- }
-
- func (Test) TableName() string {
- return "test"
- }
-
- func main() {
- orm, err := gorm.Open(mysql.Open(dsn))
- if err != nil {
- panic(err)
- }
-
- handleOrm(orm)
- }
-
- func handleOrm(orm *gorm.DB) {
- var rows []Test
-
- clause := func(db *gorm.DB) *gorm.DB {
- return db.Where("id >= ?", 1)
- }
- err := clause(orm.Select("*")).Find(&rows).Error
- if err != nil {
- panic(err)
- }
-
- b, err := json.Marshal(rows)
- if err != nil {
- panic(err)
- }
- fmt.Println("gorm", string(b))
- }
- package main
-
- import (
- "log"
- "os"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "gorm.io/driver/sqlite"
- "gorm.io/gorm"
- "gorm.io/gorm/logger"
- )
-
- func Test_handleOrm(t *testing.T) {
- db := NewMemoryDB()
- err := db.Migrator().CreateTable(&Test{})
- assert.Nil(t, err)
-
- handleOrm(db)
- }
-
- func NewMemoryDB() *gorm.DB {
- var db *gorm.DB
- var err error
- newLogger := logger.New(
- log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
- logger.Config{
- SlowThreshold: time.Second, // 慢 SQL 阈值
- LogLevel: logger.Info, // Log level
- Colorful: false, // 禁用彩色打印
- },
- )
- dialector := sqlite.Open(":memory:?cache=shared")
- if db, err = gorm.Open(dialector, &gorm.Config{
- Logger: newLogger,
- }); err != nil {
- panic(err)
- }
- dba, err := db.DB()
- dba.SetMaxOpenConns(1)
- return db
- }
-
- func CloseMemoryDB(db *gorm.DB) {
- sqlDB, _ := db.DB()
- sqlDB.Close()
- }
http.go
- package main
-
- import (
- "fmt"
- "net/http"
- "time"
- )
-
- func Send() (err error) {
- req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1:8080", nil)
- if err != nil {
- return
- }
- client := &http.Client{
- Timeout: time.Second,
- }
- resp, err := client.Do(req)
- if err != nil {
- return
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- return fmt.Errorf("HTTP status is %d", resp.StatusCode)
- }
-
- return
- }
- package main
-
- import (
- "net/http"
- "testing"
-
- "github.com/jarcoal/httpmock"
- "github.com/smartystreets/goconvey/convey"
- "github.com/stretchr/testify/assert"
- )
-
- func TestSend(t *testing.T) {
- convey.Convey("TestSend", t, func() {
- convey.Convey("success", func() {
- httpmock.Activate()
- defer httpmock.DeactivateAndReset()
- httpmock.RegisterResponder(http.MethodGet, "https://127.0.0.1:8080", httpmock.NewStringResponder(http.StatusOK, ""))
-
- err := Send()
- assert.Nil(t, err)
- })
- })
- }