Skip to content

Commit 1cc9fb9

Browse files
committed
fix for parallel calls with prepared stmts
1 parent b2f0b45 commit 1cc9fb9

File tree

2 files changed

+143
-0
lines changed

2 files changed

+143
-0
lines changed

sqlmock.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,14 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
296296
if next.fulfilled() {
297297
next.Unlock()
298298
fulfilled++
299+
300+
if pr, ok := next.(*ExpectedPrepare); ok {
301+
if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil {
302+
expected = pr
303+
next.Lock()
304+
break
305+
}
306+
}
299307
continue
300308
}
301309

@@ -334,6 +342,14 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
334342
}
335343

336344
func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare {
345+
for _, e := range c.expected {
346+
if ep, ok := e.(*ExpectedPrepare); ok {
347+
if ep.expectSQL == expectedSQL {
348+
return ep
349+
}
350+
}
351+
}
352+
337353
e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c}
338354
c.expected = append(c.expected, e)
339355
return e

sqlmock_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,50 @@ func TestUnorderedPreparedQueryExecutions(t *testing.T) {
394394
}
395395
}
396396

397+
func TestParallelPreparedQueryExecutions(t *testing.T) {
398+
db, mock, err := New()
399+
if err != nil {
400+
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
401+
}
402+
mock.MatchExpectationsInOrder(false)
403+
404+
mock.ExpectPrepare("INSERT INTO authors \\((.+)\\) VALUES \\((.+)\\)").
405+
ExpectExec().
406+
WithArgs(1, "Jane Doe").
407+
WillReturnResult(NewResult(1, 1))
408+
409+
mock.ExpectPrepare("INSERT INTO authors \\((.+)\\) VALUES \\((.+)\\)").
410+
ExpectExec().
411+
WithArgs(0, "John Doe").
412+
WillReturnResult(NewResult(0, 1))
413+
414+
t.Run("Parallel1", func(t *testing.T) {
415+
t.Parallel()
416+
417+
stmt, err := db.Prepare("INSERT INTO authors (id, name) VALUES (?, ?)")
418+
if err != nil {
419+
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
420+
} else {
421+
_, err = stmt.Exec(0, "John Doe")
422+
}
423+
})
424+
425+
t.Run("Parallel2", func(t *testing.T) {
426+
t.Parallel()
427+
428+
stmt, err := db.Prepare("INSERT INTO authors (id, name) VALUES (?, ?)")
429+
if err != nil {
430+
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
431+
} else {
432+
_, err = stmt.Exec(1, "Jane Doe")
433+
}
434+
})
435+
436+
t.Cleanup(func() {
437+
db.Close()
438+
})
439+
}
440+
397441
func TestUnexpectedOperations(t *testing.T) {
398442
t.Parallel()
399443
db, mock, err := New()
@@ -632,6 +676,89 @@ func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) {
632676
// note this line is important for unordered expectation matching
633677
mock.MatchExpectationsInOrder(false)
634678

679+
data := []interface{}{
680+
1,
681+
"John Doe",
682+
2,
683+
"Jane Doe",
684+
}
685+
rows := NewRows([]string{"id", "name"})
686+
rows.AddRow(data[0], data[1])
687+
rows.AddRow(data[2], data[3])
688+
689+
mock.ExpectExec("DROP TABLE IF EXISTS author").WillReturnResult(NewResult(0, 0))
690+
mock.ExpectExec("TRUNCATE TABLE").WillReturnResult(NewResult(0, 0))
691+
692+
mock.ExpectExec("CREATE TABLE IF NOT EXISTS author").WillReturnResult(NewResult(0, 0))
693+
694+
mock.ExpectQuery("SELECT").WillReturnRows(rows).WithArgs()
695+
696+
mock.ExpectPrepare("INSERT INTO").
697+
ExpectExec().
698+
WithArgs(
699+
data[0],
700+
data[1],
701+
data[2],
702+
data[3],
703+
).
704+
WillReturnResult(NewResult(0, 2))
705+
706+
var wg sync.WaitGroup
707+
queries := []func() error{
708+
func() error {
709+
_, err := db.Exec("CREATE TABLE IF NOT EXISTS author (a varchar(255)")
710+
return err
711+
},
712+
func() error {
713+
_, err := db.Exec("TRUNCATE TABLE author")
714+
return err
715+
},
716+
func() error {
717+
stmt, err := db.Prepare("INSERT INTO author (id,name) VALUES (?,?),(?,?)")
718+
if err != nil {
719+
return err
720+
}
721+
_, err = stmt.Exec(1, "John Doe", 2, "Jane Doe")
722+
return err
723+
},
724+
func() error {
725+
_, err := db.Query("SELECT * FROM author")
726+
return err
727+
},
728+
func() error {
729+
_, err := db.Exec("DROP TABLE IF EXISTS author")
730+
return err
731+
},
732+
}
733+
734+
wg.Add(len(queries))
735+
for _, f := range queries {
736+
go func(f func() error) {
737+
if err := f(); err != nil {
738+
t.Errorf("error was not expected: %s", err)
739+
}
740+
wg.Done()
741+
}(f)
742+
}
743+
744+
wg.Wait()
745+
746+
if err := mock.ExpectationsWereMet(); err != nil {
747+
t.Errorf("there were unfulfilled expectations: %s", err)
748+
}
749+
}
750+
751+
func TestGoroutineExecutionMultiTypes(t *testing.T) {
752+
t.Parallel()
753+
db, mock, err := New()
754+
if err != nil {
755+
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
756+
}
757+
defer db.Close()
758+
759+
// note this line is important for unordered expectation matching
760+
mock.MatchExpectationsInOrder(false)
761+
635762
result := NewResult(1, 1)
636763

637764
mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result)

0 commit comments

Comments
 (0)