1. 场景
当前golang开发人员,在编写完成代码后,通常会写对应的单测来保证代码的健壮。对于很多大厂来说,编写单测已经是代码规范的一部分。基于官方提供的gomock
框架和mockgen
辅助工具就可以满足绝大部分场景,对于不能直接创建的依赖进行mock。但是,当我们编写API接口的时候,往往会对数据库进行操作,那么就需要支持对SQL进行mock的场景。
2. sqlmock 简介
在使用gorm等orm框架时,由于需要和数据库进行交互,并且CICD服务器在对代码检测的时候,往往也无法连接真正的数据库,因此编写单元测试,就会变得很困难。
go-sqlmock
本质是一个实现了 sql/driver 接口的 mock 库,它的设计目标是支持在测试中,模拟任何 sql driver 的行为,而不需要一个真正的数据库连接。因此,可以很好的解决这个问题。
3. 安装 go-sqlmock
go get github.com/DATA-DOG/go-sqlmock
4. sqlmock实战
首先我们模拟一下,在实际开发中会使用到gorm来对数据库查询操作。
目录结构:
- main.go: 主程序,加载
TagController
,并注入已经初始化后的*gorm.DB
, 然后调用TagController
中的方法PrintTagList()
- controller
- tag.go: 包含控制器
TagController
的代码
- tag.go: 包含控制器
- model
- tag.go: 包含model层,使用gorm需要定义的tag表的字段信息
4.1 定义接口
4.1.1 main.go
这里省去了,我们可能会用到的gin
等框架负载的启动逻辑。假设main
函数中,就是单纯的初始化gorm
,并实例化控制器后,调用控制器的方法,获取数据库中的结果。
dsn连接信息,这里预设的是本地的数据库连接信息。
package main
import (
"test/utils/sqlmock/controller"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
func main() {
db := initDB()
tagCtrl := controller.TagController{
DB: db,
}
tagCtrl.PrintTagList()
}
func initDB() *gorm.DB {
dsn := "root:@tcp(127.0.0.1:3306)/registration-service?charset=utf8&parseTime=true&loc=Asia%2FShanghai"
db, err := gorm.Open(mysql.Open(dsn))
if err != nil {
panic(err)
}
return db
}
4.1.2 controller/tag.go
类似于一般的框架,MVC架构下,通常会首先进入controller中,然后通过controller来访问model层的代码。这里提供了,TagController
的PrintTagList()
方法,来打印所有从数据库中查询出来的TagName
package controller
import (
"fmt"
"test/utils/sqlmock/model"
"gorm.io/gorm"
)
type TagController struct {
DB *gorm.DB
}
func (c *TagController) PrintTagList() {
var tagModel []*model.Tag
if err := c.DB.Find(&tagModel).Error; err != nil {
fmt.Println(err)
}
for _, tag := range tagModel {
fmt.Println(tag.TagName)
}
}
4.1.3 model/tag.go
MVC 中的model层代码,这里是按照gorm的使用规范,定义了Tag表
的结构信息。
package model
import (
"time"
)
// Tag 表
type Tag struct {
Id uint `gorm:"column:id;type:int(11) unsigned;primary_key;AUTO_INCREMENT" json:"id"`
TagName string `gorm:"column:tag_name;type:varchar(20);comment:关键字;NOT NULL" json:"tag_name"`
Enabled int32 `gorm:"column:enabled;type:tinyint(2);default:0;comment:是否启用:1启用,0禁用;NOT NULL" json:"enabled"`
CreatedAt time.Time `gorm:"column:created_at;type:datetime;comment:创建时间" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;comment:更新时间" json:"updated_at"`
}
// TableName -
func (m *Tag) TableName() string {
return "tag"
}
4.1.4 执行main.go
当然,实际的数据,已经预先写入到了数据库中。这里可以正确的被打印出来
结果:
tag1
tag2
apple
orange
water
banana
4.2 通过sqlmock来对TagController
的代码编写单测
创建controller/tag_test.go
的单测文件,填写以下信息:
package controller
import (
"fmt"
"testing"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"github.com/DATA-DOG/go-sqlmock"
)
// 初始化sqlmock
func initTest() (*gorm.DB, sqlmock.Sqlmock) {
// 1. 初始化 sql mock
db, mock, err := sqlmock.New()
if err != nil {
fmt.Println("err:", err)
}
// 2. mock数据库版本查询
mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"version"}).AddRow("5.7.25"))
// 3. mock gorm driver
gormDB, err := gorm.Open(mysql.New(mysql.Config{
Conn: db,
}), &gorm.Config{})
if err != nil {
panic(err) // Error here
}
return gormDB, mock
}
// 对PrintTagList方法单测
func TestTagController_PrintTagList(t *testing.T) {
// 初始化sqlmock
gormDB, sqlMock := initTest()
// 对即将产生的sql,预先打桩处理
mockExpect(sqlMock)
// 初始化控制器,并将mock后的gorm注入
tagController := &TagController{
DB: gormDB,
}
// 调用需要测试的方法
tagController.PrintTagList()
}
// 对即将产生的sql,预先打桩处理
func mockExpect(mock sqlmock.Sqlmock) {
mock.ExpectQuery("^SELECT (.+) FROM `tag`").WillReturnRows(sqlmock.NewRows([]string{"tag_name"}).AddRow("apple").AddRow("orange"))
}
在执行结果中,将会显mock的内容AddRow("apple").AddRow("orange")
// 执行结果:
=== RUN TestTagController_PrintTagList
apple
orange
--- PASS: TestTagController_PrintTagList (0.00s)
PASS
4.3 支持事务
4.3.1 在TagController
中增加方法Create()
func (c *TagController) Create(tagName string) {
// 开启事务
tx := c.DB.Begin()
tagModel := &model.Tag{
TagName: tagName,
Enabled: 1,
CreatedAt: time.Now(),
}
if err := tx.Create(tagModel).Error; err != nil {
// 创建失败回滚
tx.Rollback()
fmt.Println(err)
}
// 提交事务
if err := tx.Commit().Error; err != nil {
fmt.Println(err)
}
}
4.3.2 增加单测
func TestTagController_Create(t *testing.T) {
gormDB, sqlMock := initTest()
mockCreateExpect(sqlMock)
tagController := &TagController{
DB: gormDB,
}
// 测试创建失败
tagController.Create("banana")
// 测试创建成功
tagController.Create("banana")
}
func mockCreateExpect(mock sqlmock.Sqlmock) {
// mock创建失败
mock.ExpectBegin()
mock.ExpectExec("^INSERT INTO `tag` ").WillReturnError(gorm.ErrInvalidData)
mock.ExpectRollback()
// mock创建成功
mock.ExpectBegin()
mock.ExpectExec("^INSERT INTO `tag` ").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
}
其中sqlmock.AnyArg()
跳过对参数的匹配校验,但是.WithArgs()
方法要求,对参数的数量需要一致。
4.4. 完整代码
controller/tag.go
package controller
import (
"fmt"
"test/utils/sqlmock/model"
"time"
"gorm.io/gorm"
)
type TagController struct {
DB *gorm.DB
}
func (c *TagController) PrintTagList() {
var tagModel []*model.Tag
if err := c.DB.Find(&tagModel).Error; err != nil {
fmt.Println(err)
}
for _, tag := range tagModel {
fmt.Println(tag.TagName)
}
}
func (c *TagController) Create(tagName string) {
// 开启事务
tx := c.DB.Begin()
tagModel := &model.Tag{
TagName: tagName,
Enabled: 1,
CreatedAt: time.Now(),
}
if err := tx.Create(tagModel).Error; err != nil {
// 创建失败回滚
tx.Rollback()
fmt.Println(err)
}
// 提交事务
if err := tx.Commit().Error; err != nil {
fmt.Println(err)
}
}
controller/tag_test.go
package controller
import (
"fmt"
"testing"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"github.com/DATA-DOG/go-sqlmock"
)
func initTest() (*gorm.DB, sqlmock.Sqlmock) {
// 1. 初始化 sql mock
db, mock, err := sqlmock.New()
if err != nil {
fmt.Println("err:", err)
}
// 2. mock数据库版本查询
mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"version"}).AddRow("5.7.25"))
// 3. 组装mock的gorm
gormDB, err := gorm.Open(mysql.New(mysql.Config{
Conn: db,
}), &gorm.Config{})
if err != nil {
panic(err) // Error here
}
return gormDB, mock
}
func TestTagController_PrintTagList(t *testing.T) {
gormDB, sqlMock := initTest()
mockExpect(sqlMock)
tagController := &TagController{
DB: gormDB,
}
tagController.PrintTagList()
}
func TestTagController_Create(t *testing.T) {
gormDB, sqlMock := initTest()
mockCreateExpect(sqlMock)
tagController := &TagController{
DB: gormDB,
}
// 测试创建失败
tagController.Create("banana")
// 测试创建成功
tagController.Create("banana")
}
func mockExpect(mock sqlmock.Sqlmock) {
mock.ExpectQuery("^SELECT (.+) FROM `tag`").WillReturnRows(sqlmock.NewRows([]string{"tag_name"}).AddRow("apple").AddRow("orange"))
}
func mockCreateExpect(mock sqlmock.Sqlmock) {
// mock创建失败
mock.ExpectBegin()
mock.ExpectExec("^INSERT INTO `tag` ").WillReturnError(gorm.ErrInvalidData)
mock.ExpectRollback()
// mock创建成功
mock.ExpectBegin()
mock.ExpectExec("^INSERT INTO `tag` ").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
}
4.5 注意事项
- 在
initTest()
方法中,对gorm driver进行mock的时候,低版本和高版本的代码实现会有一定的差异。目前网上搜索到的示例大多数都是旧版本的实现方式,本文中的示例,是基于gorm.io/gorm v1.25.5
版本实现的。 - 在
mock.ExpectQuery()
方法中,支持正则表达式来对sql语句进行匹配。 - 初始化数据库,
SELECT VERSION()
问题
[error] failed to initialize database, got error all expectations were already fulfilled, call to Query 'SELECT VERSION()' with args [] was not expected
--- FAIL: TestTagController_Create (0.00s)
panic: all expectations were already fulfilled, call to Query 'SELECT VERSION()' with args [] was not expected [recovered]
panic: all expectations were already fulfilled, call to Query 'SELECT VERSION()' with args [] was not expected
需要增加Expect:mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"version"}).AddRow("5.7.25"))
4. 数据库连接关闭问题
sql: database is closed
sql: database is closed; invalid transaction
通过db, mock, err := sqlmock.New()
获取到db后,千万不要defer db.Close()
,否则会导致后续对数据库操作,引起database is closed
的问题。
// 1. 初始化 sql mock
db, mock, err := sqlmock.New()
// defer db.Close() 不要加这一行
if err != nil {
fmt.Println("err:", err)
}
5. 总结
上面主要是,简单的介绍和示例了,通过sqlmock来对gorm打桩mock。从而更加简单和方便的来对使用到数据库操作的业务代码进行单测的编写。
评论区