Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding common table expression support #374

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions cte.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package squirrel

import (
"bytes"
"fmt"

"github.com/lann/builder"
)

// Common Table Expressions helper
// e.g.
// WITH cte AS (
// ...
// ), cte_2 AS (
// ...
// )
// SELECT ... FROM cte ... cte_2;

type commonTableExpressionsData struct {
PlaceholderFormat PlaceholderFormat
Recursive bool
CurrentCteName string
Ctes []Sqlizer
Statement Sqlizer
}

func (d *commonTableExpressionsData) toSql() (sqlStr string, args []interface{}, err error) {
if len(d.Ctes) == 0 {
err = fmt.Errorf("common table expressions statements must have at least one label and subquery")
return
}

if d.Statement == nil {
err = fmt.Errorf("common table expressions must one of the following final statement: (select, insert, replace, update, delete)")
return
}

sql := &bytes.Buffer{}

sql.WriteString("WITH ")
if d.Recursive {
sql.WriteString("RECURSIVE ")
}

args, err = appendToSql(d.Ctes, sql, ", ", args)
sql.WriteString("\n")
args, err = appendToSql([]Sqlizer{d.Statement}, sql, "", args)

sqlStr = sql.String()
return
}

func (d *commonTableExpressionsData) ToSql() (sql string, args []interface{}, err error) {
return d.toSql()
}

// Builder

// CommonTableExpressionsBuilder builds CTE (Common Table Expressions) SQL statements.
type CommonTableExpressionsBuilder builder.Builder

func init() {
builder.Register(CommonTableExpressionsBuilder{}, commonTableExpressionsData{})
}

// Format methods

// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b CommonTableExpressionsBuilder) PlaceholderFormat(f PlaceholderFormat) CommonTableExpressionsBuilder {
return builder.Set(b, "PlaceholderFormat", f).(CommonTableExpressionsBuilder)
}

// SQL methods

// ToSql builds the query into a SQL string and bound args.
func (b CommonTableExpressionsBuilder) ToSql() (string, []interface{}, error) {
data := builder.GetStruct(b).(commonTableExpressionsData)
return data.ToSql()
}

// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b CommonTableExpressionsBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}

func (b CommonTableExpressionsBuilder) Recursive(recursive bool) CommonTableExpressionsBuilder {
return builder.Set(b, "Recursive", recursive).(CommonTableExpressionsBuilder)
}

// Cte starts a new cte
func (b CommonTableExpressionsBuilder) Cte(cte string) CommonTableExpressionsBuilder {
return builder.Set(b, "CurrentCteName", cte).(CommonTableExpressionsBuilder)
}

// As sets the expression for the Cte
func (b CommonTableExpressionsBuilder) As(as SelectBuilder) CommonTableExpressionsBuilder {
data := builder.GetStruct(b).(commonTableExpressionsData)
return builder.Append(b, "Ctes", cteExpr{as, data.CurrentCteName}).(CommonTableExpressionsBuilder)
}

// Select finalizes the CommonTableExpressionsBuilder with a SELECT
func (b CommonTableExpressionsBuilder) Select(statement SelectBuilder) CommonTableExpressionsBuilder {
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
}

// Insert finalizes the CommonTableExpressionsBuilder with an INSERT
func (b CommonTableExpressionsBuilder) Insert(statement InsertBuilder) CommonTableExpressionsBuilder {
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
}

// Replace finalizes the CommonTableExpressionsBuilder with a REPLACE
func (b CommonTableExpressionsBuilder) Replace(statement InsertBuilder) CommonTableExpressionsBuilder {
return b.Insert(statement)
}

// Update finalizes the CommonTableExpressionsBuilder with an UPDATE
func (b CommonTableExpressionsBuilder) Update(statement UpdateBuilder) CommonTableExpressionsBuilder {
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
}

// Delete finalizes the CommonTableExpressionsBuilder with a DELETE
func (b CommonTableExpressionsBuilder) Delete(statement DeleteBuilder) CommonTableExpressionsBuilder {
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
}
135 changes: 135 additions & 0 deletions cte_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package squirrel

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestWithAsQuery_OneSubquery(t *testing.T) {
w := With("lab").As(
Select("col").From("tab").
Where("simple").
Where("NOT hard"),
).Select(
Select("col").
From("lab"),
)
q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab AS (\n" +
"SELECT col FROM tab WHERE simple AND NOT hard\n" +
")\n" +
"SELECT col FROM lab"
assert.Equal(t, expectedSql, q)

w = WithRecursive("lab").As(
Select("col").From("tab").
Where("simple").
Where("NOT hard"),
).Select(Select("col").
From("lab"),
)
q, _, err = w.ToSql()
assert.NoError(t, err)

expectedSql = "WITH RECURSIVE lab AS (\n" +
"SELECT col FROM tab WHERE simple AND NOT hard\n" +
")\n" +
"SELECT col FROM lab"
assert.Equal(t, expectedSql, q)
}

func TestWithAsQuery_TwoSubqueries(t *testing.T) {
w := With("lab_1").As(
Select("col_1", "col_common").From("tab_1").
Where("simple").
Where("NOT hard"),
).Cte("lab_2").As(
Select("col_2", "col_common").From("tab_2"),
).Select(Select("col_1", "col_2", "col_common").
From("lab_1").Join("lab_2 ON lab_1.col_common = lab_2.col_common"),
)
q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab_1 AS (\n" +
"SELECT col_1, col_common FROM tab_1 WHERE simple AND NOT hard\n" +
"), lab_2 AS (\n" +
"SELECT col_2, col_common FROM tab_2\n" +
")\n" +
"SELECT col_1, col_2, col_common FROM lab_1 JOIN lab_2 ON lab_1.col_common = lab_2.col_common"
assert.Equal(t, expectedSql, q)
}

func TestWithAsQuery_ManySubqueries(t *testing.T) {
w := With("lab_1").As(
Select("col_1", "col_common").From("tab_1").
Where("simple").
Where("NOT hard"),
).Cte("lab_2").As(
Select("col_2", "col_common").From("tab_2"),
).Cte("lab_3").As(
Select("col_3", "col_common").From("tab_3"),
).Cte("lab_4").As(
Select("col_4", "col_common").From("tab_4"),
).Select(
Select("col_1", "col_2", "col_3", "col_4", "col_common").
From("lab_1").Join("lab_2 ON lab_1.col_common = lab_2.col_common").
Join("lab_3 ON lab_1.col_common = lab_3.col_common").
Join("lab_4 ON lab_1.col_common = lab_4.col_common"),
)
q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab_1 AS (\n" +
"SELECT col_1, col_common FROM tab_1 WHERE simple AND NOT hard\n" +
"), lab_2 AS (\n" +
"SELECT col_2, col_common FROM tab_2\n" +
"), lab_3 AS (\n" +
"SELECT col_3, col_common FROM tab_3\n" +
"), lab_4 AS (\n" +
"SELECT col_4, col_common FROM tab_4\n" +
")\n" +
"SELECT col_1, col_2, col_3, col_4, col_common FROM lab_1 JOIN lab_2 ON lab_1.col_common = lab_2.col_common JOIN lab_3 ON lab_1.col_common = lab_3.col_common JOIN lab_4 ON lab_1.col_common = lab_4.col_common"
assert.Equal(t, expectedSql, q)
}

func TestWithAsQuery_Insert(t *testing.T) {
w := With("lab").As(
Select("col").From("tab").
Where("simple").
Where("NOT hard"),
).Insert(Insert("ins_tab").Columns("ins_col").Select(Select("col").From("lab")))
q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab AS (\n" +
"SELECT col FROM tab WHERE simple AND NOT hard\n" +
")\n" +
"INSERT INTO ins_tab (ins_col) SELECT col FROM lab"
assert.Equal(t, expectedSql, q)
}

func TestWithAsQuery_Update(t *testing.T) {
w := With("lab").As(
Select("col", "common_col").From("tab").
Where("simple").
Where("NOT hard"),
).Update(
Update("upd_tab, lab").
Set("upd_col", Expr("lab.col")).
Where("common_col = lab.common_col"),
)

q, _, err := w.ToSql()
assert.NoError(t, err)

expectedSql := "WITH lab AS (\n" +
"SELECT col, common_col FROM tab WHERE simple AND NOT hard\n" +
")\n" +
"UPDATE upd_tab, lab SET upd_col = lab.col WHERE common_col = lab.common_col"

assert.Equal(t, expectedSql, q)
}
Loading