Skip to content

Commit

Permalink
Add Join, LeftJoin, RightJoin and FullJoin
Browse files Browse the repository at this point in the history
  • Loading branch information
leporo committed Sep 6, 2019
1 parent 6bb7952 commit 5a17b38
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 7 deletions.
9 changes: 3 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ fmt.Printf("Most expensive offer: $%.2f\n", minAmount)

#### Joins

There are no helper methods to construct a JOIN clause.

Consider using the "old style" syntax for INNER JOINs:
There are helper methods to construct a JOIN clause: `Join`, `LeftJoin`, `RightJoin` and `FullJoin`.

```go
var (
Expand All @@ -177,8 +175,7 @@ err := sqlf.From("offers o").
Select("price").To(&price).
Where("is_deleted = false").
// Join
From("products p").
Where("p.id = o.product_id").
LeftJoin("products p", "p.id = o.product_id").
// Bind a column from joined table to variable
Select("p.name").To(&productName).
// Print top 10 offers
Expand All @@ -200,7 +197,7 @@ var (
name string
value string
)
err := sqlf.From("t1 FULL JOIN t2 ON t1.num = t2.num AND t2.value IN (?, ?)", "xxx", "yyy").
err := sqlf.From("t1 CROSS JOIN t2 ON t1.num = t2.num AND t2.value IN (?, ?)", "xxx", "yyy").
Select("t1.num").To(&num).
Select("t1.name").To(&name).
Select("t2.value").To(&value).
Expand Down
15 changes: 15 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,18 @@ func BenchmarkWith(b *testing.B) {
q.Close()
}
}

func BenchmarkIn(b *testing.B) {
a := make([]interface{}, 50)
for i := 0; i < len(a); i++ {
a[i] = i + 1
}
b.ResetTimer()
for n := 0; n < b.N; n++ {
q := sqlf.From("orders").
Select("id").
Where("status").In(a...)
s = q.String()
q.Close()
}
}
56 changes: 55 additions & 1 deletion stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,10 @@ func (q *Stmt) From(expr string, args ...interface{}) *Stmt {
/*
Where adds a filter:
sqlf.Select("id, name").From("users").Where("email = ?", email).Where("is_active = 1")
sqlf.From("users").
Select("id, name").
Where("email = ?", email).
Where("is_active = 1")
*/
func (q *Stmt) Where(expr string, args ...interface{}) *Stmt {
Expand All @@ -339,11 +342,45 @@ func (q *Stmt) In(args ...interface{}) *Stmt {
}
}
buf.WriteString(")")

q.addChunk(posWhere, "", bufToString(&buf.B), args, " ")

bytebufferpool.Put(buf)
return q
}

/*
Join adds an INNERT JOIN clause to SELECT statement
*/
func (q *Stmt) Join(table, on string) *Stmt {
q.join("JOIN ", table, on)
return q
}

/*
LeftJoin adds a LEFT OUTER JOIN clause to SELECT statement
*/
func (q *Stmt) LeftJoin(table, on string) *Stmt {
q.join("LEFT JOIN ", table, on)
return q
}

/*
RightJoin adds a RIGHT OUTER JOIN clause to SELECT statement
*/
func (q *Stmt) RightJoin(table, on string) *Stmt {
q.join("RIGHT JOIN ", table, on)
return q
}

/*
FullJoin adds a FULL OUTER JOIN clause to SELECT statement
*/
func (q *Stmt) FullJoin(table, on string) *Stmt {
q.join("FULL JOIN ", table, on)
return q
}

// OrderBy adds the ORDER BY clause to SELECT statement
func (q *Stmt) OrderBy(expr ...string) *Stmt {
q.addChunk(posOrderBy, "ORDER BY", strings.Join(expr, ", "), nil, ", ")
Expand Down Expand Up @@ -555,6 +592,22 @@ func (q *Stmt) Clone() *Stmt {
return stmt
}

// join adds a join clause to a SELECT statement
func (q *Stmt) join(joinType, table, on string) (index int) {
buf := bytebufferpool.Get()
buf.WriteString(joinType)
buf.WriteString(table)
buf.Write(joinOn)
buf.WriteString(on)
buf.WriteByte(')')

index = q.addChunk(posFrom, "", bufToString(&buf.B), nil, " ")

bytebufferpool.Put(buf)

return index
}

// addChunk adds a clause or expression to a statement.
func (q *Stmt) addChunk(pos int, clause, expr string, args []interface{}, sep string) (index int) {
// Remember the position
Expand Down Expand Up @@ -653,6 +706,7 @@ var (
space = []byte{' '}
placeholder = []byte{'?'}
placeholderComma = []byte{'?', ','}
joinOn = []byte{' ', 'O', 'N', ' ', '('}
)

const (
Expand Down
24 changes: 24 additions & 0 deletions stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,27 @@ func TestClone(t *testing.T) {
assert.NotEqual(t, q.Args(), q2.Args())
assert.NotEqual(t, q.String(), q2.String())
}

func TestJoin(t *testing.T) {
q := sqlf.From("orders o").Select("id").Join("users u", "u.id = o.user_id")
defer q.Close()
assert.Equal(t, "SELECT id FROM orders o JOIN users u ON (u.id = o.user_id)", q.String())
}

func TestLeftJoin(t *testing.T) {
q := sqlf.From("orders o").Select("id").LeftJoin("users u", "u.id = o.user_id")
defer q.Close()
assert.Equal(t, "SELECT id FROM orders o LEFT JOIN users u ON (u.id = o.user_id)", q.String())
}

func TestRightJoin(t *testing.T) {
q := sqlf.From("orders o").Select("id").RightJoin("users u", "u.id = o.user_id")
defer q.Close()
assert.Equal(t, "SELECT id FROM orders o RIGHT JOIN users u ON (u.id = o.user_id)", q.String())
}

func TestFullJoin(t *testing.T) {
q := sqlf.From("orders o").Select("id").FullJoin("users u", "u.id = o.user_id")
defer q.Close()
assert.Equal(t, "SELECT id FROM orders o FULL JOIN users u ON (u.id = o.user_id)", q.String())
}

0 comments on commit 5a17b38

Please sign in to comment.