Skip to content
This repository has been archived by the owner on Sep 27, 2024. It is now read-only.

Commit

Permalink
Add DBConn.ExecContext
Browse files Browse the repository at this point in the history
Context allows queries to be cancelled. This can be used when we
expect a SIGINT/SIGTERM could occur during a query that could block
(e.g. trying to SELECT from a table but some other session explicitly
grabbed an AccessExclusiveLock and will not be giving it up anytime
soon). Without the ability to cancel queries, exiting Go programs
could lead to SQL session leaks. Simply closing the SQL connection
DOES NOT cancel the query for you.
  • Loading branch information
jimmyyih committed Oct 11, 2019
1 parent f202136 commit 39fae54
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
14 changes: 14 additions & 0 deletions dbconn/dbconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package dbconn
*/

import (
"context"
"database/sql"
"fmt"
"strconv"
Expand Down Expand Up @@ -257,6 +258,19 @@ func (dbconn *DBConn) MustExec(query string, whichConn ...int) {
gplog.FatalOnError(err)
}

func (dbconn *DBConn) ExecContext(queryContext context.Context, query string, whichConn ...int) (sql.Result, error) {
connNum := dbconn.ValidateConnNum(whichConn...)
if dbconn.Tx[connNum] != nil {
return dbconn.Tx[connNum].ExecContext(queryContext, query)
}
return dbconn.ConnPool[connNum].ExecContext(queryContext, query)
}

func (dbconn *DBConn) MustExecContext(queryContext context.Context, query string, whichConn ...int) {
_, err := dbconn.ExecContext(queryContext, query, whichConn...)
gplog.FatalOnError(err)
}

func (dbconn *DBConn) GetWithArgs(destination interface{}, query string, args ...interface{}) error {
if dbconn.Tx[0] != nil {
return dbconn.Tx[0].Get(destination, query, args...)
Expand Down
31 changes: 31 additions & 0 deletions dbconn/dbconn_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dbconn_test

import (
"context"
"database/sql/driver"
"fmt"
"os"
Expand Down Expand Up @@ -183,6 +184,36 @@ var _ = Describe("dbconn/dbconn tests", func() {
Expect(rowsReturned).To(Equal(int64(1)))
})
})
Describe("DBConn.ExecContext", func() {
It("executes an INSERT outside of a transaction", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

fakeResult := testhelper.TestResult{Rows: 1}
mock.ExpectExec("INSERT (.*)").WillReturnResult(fakeResult)

res, err := connection.ExecContext(ctx, "INSERT INTO pg_tables VALUES ('schema', 'table')")
Expect(err).ToNot(HaveOccurred())
rowsReturned, err := res.RowsAffected()
Expect(rowsReturned).To(Equal(int64(1)))
})
It("executes an INSERT in a transaction", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

fakeResult := testhelper.TestResult{Rows: 1}
ExpectBegin(mock)
mock.ExpectExec("INSERT (.*)").WillReturnResult(fakeResult)
mock.ExpectCommit()

connection.MustBegin()
res, err := connection.ExecContext(ctx, "INSERT INTO pg_tables VALUES ('schema', 'table')")
connection.MustCommit()
Expect(err).ToNot(HaveOccurred())
rowsReturned, err := res.RowsAffected()
Expect(rowsReturned).To(Equal(int64(1)))
})
})
Describe("DBConn.Get", func() {
It("executes a GET outside of a transaction", func() {
two_col_single_row := sqlmock.NewRows([]string{"schemaname", "tablename"}).
Expand Down

0 comments on commit 39fae54

Please sign in to comment.