Skip to content

Commit 28952ef

Browse files
committed
add output parameters support
1 parent 6622224 commit 28952ef

File tree

6 files changed

+418
-0
lines changed

6 files changed

+418
-0
lines changed

examples/output/output.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package main
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
)
7+
8+
// this type emulate go-mssqldb's ReturnStatus type, used to get rc from SQL Server stored procedures
9+
// https://github.com/microsoft/go-mssqldb/blob/main/mssql.go
10+
type ReturnStatus int32
11+
12+
func execWithNamedOutputArgs(db *sql.DB, outputArg *string, inputOutputArg *string) (err error) {
13+
_, err = db.Exec("EXEC spWithNamedOutputParameters",
14+
sql.Named("outArg", sql.Out{Dest: outputArg}),
15+
sql.Named("inoutArg", sql.Out{In: true, Dest: inputOutputArg}),
16+
)
17+
if err != nil {
18+
return
19+
}
20+
return
21+
}
22+
23+
func execWithTypedOutputArgs(db *sql.DB, rcArg *ReturnStatus) (err error) {
24+
if _, err = db.Exec("EXEC spWithReturnCode", rcArg); err != nil {
25+
return
26+
}
27+
return
28+
}
29+
30+
func main() {
31+
// @NOTE: the real connection is not required for tests
32+
db, err := sql.Open("mssql", "myconnectionstring")
33+
if err != nil {
34+
panic(err)
35+
}
36+
defer db.Close()
37+
38+
outputArg := ""
39+
inputOutputArg := "abcInput"
40+
41+
if err = execWithNamedOutputArgs(db, &outputArg, &inputOutputArg); err != nil {
42+
panic(err)
43+
}
44+
45+
rcArg := new(ReturnStatus)
46+
if err = execWithTypedOutputArgs(db, rcArg); err != nil {
47+
panic(err)
48+
}
49+
50+
if _, err = fmt.Printf("outputArg: %s, inputOutputArg: %s, rcArg: %d", outputArg, inputOutputArg, *rcArg); err != nil {
51+
panic(err)
52+
}
53+
}

examples/output/output_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package main
2+
3+
import (
4+
"testing"
5+
6+
"github.com/DATA-DOG/go-sqlmock"
7+
)
8+
9+
func TestNamedOutputArgs(t *testing.T) {
10+
db, mock, err := sqlmock.New()
11+
if err != nil {
12+
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
13+
}
14+
defer db.Close()
15+
16+
inOutInputValue := "abcInput"
17+
mock.ExpectExec("EXEC spWithNamedOutputParameters").
18+
WithArgs(
19+
sqlmock.NamedOutputArg("outArg", "123Output"),
20+
sqlmock.NamedInputOutputArg("inoutArg", &inOutInputValue, "abcOutput"),
21+
).
22+
WillReturnResult(sqlmock.NewResult(1, 1))
23+
24+
// now we execute our method
25+
outArg := ""
26+
inoutArg := "abcInput"
27+
if err = execWithNamedOutputArgs(db, &outArg, &inoutArg); err != nil {
28+
t.Errorf("error was not expected while updating stats: %s", err)
29+
}
30+
31+
// we make sure that all expectations were met
32+
if err := mock.ExpectationsWereMet(); err != nil {
33+
t.Errorf("there were unfulfilled expectations: %s", err)
34+
}
35+
36+
if outArg != "123Output" {
37+
t.Errorf("unexpected outArg value")
38+
}
39+
40+
if inoutArg != "abcOutput" {
41+
t.Errorf("unexpected inoutArg value")
42+
}
43+
}
44+
45+
func TestTypedOutputArgs(t *testing.T) {
46+
rcArg := new(ReturnStatus) // here we will store the return code
47+
48+
valueConverter := sqlmock.NewPassthroughValueConverter(rcArg) // we need this converter to bypass the default ValueConverter logic that alter original value's type
49+
db, mock, err := sqlmock.New(sqlmock.ValueConverterOption(valueConverter))
50+
if err != nil {
51+
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
52+
}
53+
defer db.Close()
54+
55+
rcFromSp := ReturnStatus(123) // simulate the return code from the stored procedure
56+
mock.ExpectExec("EXEC spWithReturnCode").
57+
WithArgs(
58+
sqlmock.TypedOutputArg(&rcFromSp), // using this func we can provide the expected type and value
59+
).
60+
WillReturnResult(sqlmock.NewResult(1, 1))
61+
62+
// now we execute our method
63+
if err = execWithTypedOutputArgs(db, rcArg); err != nil {
64+
t.Errorf("error was not expected while updating stats: %s", err)
65+
}
66+
67+
// we make sure that all expectations were met
68+
if err := mock.ExpectationsWereMet(); err != nil {
69+
t.Errorf("there were unfulfilled expectations: %s", err)
70+
}
71+
72+
if *rcArg != 123 {
73+
t.Errorf("unexpected rcArg value")
74+
}
75+
}

output_args.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package sqlmock
2+
3+
import (
4+
"database/sql"
5+
"database/sql/driver"
6+
"fmt"
7+
"reflect"
8+
)
9+
10+
type namedInOutValue struct {
11+
Name string
12+
ExpectedInValue interface{}
13+
ReturnedOutValue interface{}
14+
In bool
15+
}
16+
17+
// Match implements the Argument interface, allowing check if the given value matches the expected input value provided using NamedInputOutputArg function.
18+
func (n namedInOutValue) Match(v driver.Value) bool {
19+
out, ok := v.(sql.Out)
20+
21+
return ok && out.In == n.In && (!n.In || reflect.DeepEqual(out.Dest, n.ExpectedInValue))
22+
}
23+
24+
// NamedInputArg can ben used to simulate an output value passed back from the database.
25+
// returnedOutValue can be a value or a pointer to the value.
26+
func NamedOutputArg(name string, returnedOutValue interface{}) interface{} {
27+
return namedInOutValue{
28+
Name: name,
29+
ReturnedOutValue: returnedOutValue,
30+
In: false,
31+
}
32+
}
33+
34+
// NamedInputOutputArg can be used to both check if expected input value is provided and to simulate an output value passed back from the database.
35+
// expectedInValue must be a pointer to the value, returnedOutValue can be a value or a pointer to the value.
36+
func NamedInputOutputArg(name string, expectedInValue interface{}, returnedOutValue interface{}) interface{} {
37+
return namedInOutValue{
38+
Name: name,
39+
ExpectedInValue: expectedInValue,
40+
ReturnedOutValue: returnedOutValue,
41+
In: true,
42+
}
43+
}
44+
45+
type typedOutValue struct {
46+
TypeName string
47+
ReturnedOutValue interface{}
48+
}
49+
50+
// Match implements the Argument interface, allowing check if the given value matches the expected type provided using TypedOutputArg function.
51+
func (n typedOutValue) Match(v driver.Value) bool {
52+
return n.TypeName == fmt.Sprintf("%T", v)
53+
}
54+
55+
// TypeOutputArg can be used to simulate an output value passed back from the database, setting value based on the type.
56+
// returnedOutValue must be a pointer to the value.
57+
func TypedOutputArg(returnedOutValue interface{}) interface{} {
58+
return typedOutValue{
59+
TypeName: fmt.Sprintf("%T", returnedOutValue),
60+
ReturnedOutValue: returnedOutValue,
61+
}
62+
}
63+
64+
func setOutputValues(currentArgs []driver.NamedValue, expectedArgs []driver.Value) {
65+
for _, expectedArg := range expectedArgs {
66+
if outVal, ok := expectedArg.(namedInOutValue); ok {
67+
for _, currentArg := range currentArgs {
68+
if currentArg.Name == outVal.Name {
69+
if sqlOut, ok := currentArg.Value.(sql.Out); ok {
70+
reflect.ValueOf(sqlOut.Dest).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue)))
71+
}
72+
73+
break
74+
}
75+
}
76+
}
77+
78+
if outVal, ok := expectedArg.(typedOutValue); ok {
79+
for _, currentArg := range currentArgs {
80+
if fmt.Sprintf("%T", currentArg.Value) == outVal.TypeName {
81+
reflect.ValueOf(currentArg.Value).Elem().Set(reflect.Indirect(reflect.ValueOf(outVal.ReturnedOutValue)))
82+
83+
break
84+
}
85+
}
86+
}
87+
}
88+
}

passthroughvalueconverter.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package sqlmock
2+
3+
import (
4+
"database/sql/driver"
5+
"fmt"
6+
)
7+
8+
type PassthroughValueConverter struct {
9+
passthroughTypes []string
10+
}
11+
12+
func NewPassthroughValueConverter(typeSamples ...interface{}) *PassthroughValueConverter {
13+
c := &PassthroughValueConverter{}
14+
15+
for _, sampleValue := range typeSamples {
16+
c.passthroughTypes = append(c.passthroughTypes, fmt.Sprintf("%T", sampleValue))
17+
}
18+
19+
return c
20+
}
21+
22+
func (c *PassthroughValueConverter) ConvertValue(v interface{}) (driver.Value, error) {
23+
valueType := fmt.Sprintf("%T", v)
24+
for _, passthroughType := range c.passthroughTypes {
25+
if valueType == passthroughType {
26+
return v, nil
27+
}
28+
}
29+
30+
return driver.DefaultParameterConverter.ConvertValue(v)
31+
}

sqlmock_go18.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
//go:build go1.8
12
// +build go1.8
23

34
package sqlmock
@@ -250,6 +251,8 @@ func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery,
250251
return expected, expected.err // mocked to return error
251252
}
252253

254+
setOutputValues(args, expected.args)
255+
253256
if expected.rows == nil {
254257
return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
255258
}
@@ -332,6 +335,8 @@ func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, e
332335
return expected, expected.err // mocked to return error
333336
}
334337

338+
setOutputValues(args, expected.args)
339+
335340
if expected.result == nil {
336341
return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected)
337342
}

0 commit comments

Comments
 (0)