From 5a17b38900228cac919f1e8a4f9576ba3905ba64 Mon Sep 17 00:00:00 2001 From: Vlad Glushchuk Date: Fri, 6 Sep 2019 10:06:29 +0200 Subject: [PATCH] Add Join, LeftJoin, RightJoin and FullJoin --- README.md | 9 +++------ bench_test.go | 15 ++++++++++++++ stmt.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++- stmt_test.go | 24 ++++++++++++++++++++++ 4 files changed, 97 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index bb7ea6b..beff640 100644 --- a/README.md +++ b/README.md @@ -161,9 +161,7 @@ fmt.Printf("Most expensive offer: $%.2f\n", minAmount) #### Joins -There are no helper methods to construct a JOIN clause. - -Consider using the "old style" syntax for INNER JOINs: +There are helper methods to construct a JOIN clause: `Join`, `LeftJoin`, `RightJoin` and `FullJoin`. ```go var ( @@ -177,8 +175,7 @@ err := sqlf.From("offers o"). Select("price").To(&price). Where("is_deleted = false"). // Join - From("products p"). - Where("p.id = o.product_id"). + LeftJoin("products p", "p.id = o.product_id"). // Bind a column from joined table to variable Select("p.name").To(&productName). // Print top 10 offers @@ -200,7 +197,7 @@ var ( name string value string ) -err := sqlf.From("t1 FULL JOIN t2 ON t1.num = t2.num AND t2.value IN (?, ?)", "xxx", "yyy"). +err := sqlf.From("t1 CROSS JOIN t2 ON t1.num = t2.num AND t2.value IN (?, ?)", "xxx", "yyy"). Select("t1.num").To(&num). Select("t1.name").To(&name). Select("t2.value").To(&value). diff --git a/bench_test.go b/bench_test.go index d95386a..e8e2ee6 100644 --- a/bench_test.go +++ b/bench_test.go @@ -221,3 +221,18 @@ func BenchmarkWith(b *testing.B) { q.Close() } } + +func BenchmarkIn(b *testing.B) { + a := make([]interface{}, 50) + for i := 0; i < len(a); i++ { + a[i] = i + 1 + } + b.ResetTimer() + for n := 0; n < b.N; n++ { + q := sqlf.From("orders"). + Select("id"). + Where("status").In(a...) + s = q.String() + q.Close() + } +} diff --git a/stmt.go b/stmt.go index c31d705..19f750c 100644 --- a/stmt.go +++ b/stmt.go @@ -314,7 +314,10 @@ func (q *Stmt) From(expr string, args ...interface{}) *Stmt { /* Where adds a filter: - sqlf.Select("id, name").From("users").Where("email = ?", email).Where("is_active = 1") + sqlf.From("users"). + Select("id, name"). + Where("email = ?", email). + Where("is_active = 1") */ func (q *Stmt) Where(expr string, args ...interface{}) *Stmt { @@ -339,11 +342,45 @@ func (q *Stmt) In(args ...interface{}) *Stmt { } } buf.WriteString(")") + q.addChunk(posWhere, "", bufToString(&buf.B), args, " ") + bytebufferpool.Put(buf) return q } +/* +Join adds an INNERT JOIN clause to SELECT statement +*/ +func (q *Stmt) Join(table, on string) *Stmt { + q.join("JOIN ", table, on) + return q +} + +/* +LeftJoin adds a LEFT OUTER JOIN clause to SELECT statement +*/ +func (q *Stmt) LeftJoin(table, on string) *Stmt { + q.join("LEFT JOIN ", table, on) + return q +} + +/* +RightJoin adds a RIGHT OUTER JOIN clause to SELECT statement +*/ +func (q *Stmt) RightJoin(table, on string) *Stmt { + q.join("RIGHT JOIN ", table, on) + return q +} + +/* +FullJoin adds a FULL OUTER JOIN clause to SELECT statement +*/ +func (q *Stmt) FullJoin(table, on string) *Stmt { + q.join("FULL JOIN ", table, on) + return q +} + // OrderBy adds the ORDER BY clause to SELECT statement func (q *Stmt) OrderBy(expr ...string) *Stmt { q.addChunk(posOrderBy, "ORDER BY", strings.Join(expr, ", "), nil, ", ") @@ -555,6 +592,22 @@ func (q *Stmt) Clone() *Stmt { return stmt } +// join adds a join clause to a SELECT statement +func (q *Stmt) join(joinType, table, on string) (index int) { + buf := bytebufferpool.Get() + buf.WriteString(joinType) + buf.WriteString(table) + buf.Write(joinOn) + buf.WriteString(on) + buf.WriteByte(')') + + index = q.addChunk(posFrom, "", bufToString(&buf.B), nil, " ") + + bytebufferpool.Put(buf) + + return index +} + // addChunk adds a clause or expression to a statement. func (q *Stmt) addChunk(pos int, clause, expr string, args []interface{}, sep string) (index int) { // Remember the position @@ -653,6 +706,7 @@ var ( space = []byte{' '} placeholder = []byte{'?'} placeholderComma = []byte{'?', ','} + joinOn = []byte{' ', 'O', 'N', ' ', '('} ) const ( diff --git a/stmt_test.go b/stmt_test.go index 6817ed8..d2058e3 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -209,3 +209,27 @@ func TestClone(t *testing.T) { assert.NotEqual(t, q.Args(), q2.Args()) assert.NotEqual(t, q.String(), q2.String()) } + +func TestJoin(t *testing.T) { + q := sqlf.From("orders o").Select("id").Join("users u", "u.id = o.user_id") + defer q.Close() + assert.Equal(t, "SELECT id FROM orders o JOIN users u ON (u.id = o.user_id)", q.String()) +} + +func TestLeftJoin(t *testing.T) { + q := sqlf.From("orders o").Select("id").LeftJoin("users u", "u.id = o.user_id") + defer q.Close() + assert.Equal(t, "SELECT id FROM orders o LEFT JOIN users u ON (u.id = o.user_id)", q.String()) +} + +func TestRightJoin(t *testing.T) { + q := sqlf.From("orders o").Select("id").RightJoin("users u", "u.id = o.user_id") + defer q.Close() + assert.Equal(t, "SELECT id FROM orders o RIGHT JOIN users u ON (u.id = o.user_id)", q.String()) +} + +func TestFullJoin(t *testing.T) { + q := sqlf.From("orders o").Select("id").FullJoin("users u", "u.id = o.user_id") + defer q.Close() + assert.Equal(t, "SELECT id FROM orders o FULL JOIN users u ON (u.id = o.user_id)", q.String()) +}