From 357bede0d060a364a3693dc1a15d45824d9e6c0e Mon Sep 17 00:00:00 2001 From: "joshringuk@gmail.com" Date: Thu, 22 Feb 2024 11:17:20 +0000 Subject: [PATCH] updates from pr --- cte.go | 130 +++++++++++++++++++++++++++++++++++++++++++++++++ cte_test.go | 135 +++++++++++++++++++++++++++++++++++++++++++++++++++ expr.go | 55 ++++++++++++++++----- statement.go | 19 ++++++++ 4 files changed, 326 insertions(+), 13 deletions(-) create mode 100644 cte.go create mode 100644 cte_test.go diff --git a/cte.go b/cte.go new file mode 100644 index 00000000..988ccca1 --- /dev/null +++ b/cte.go @@ -0,0 +1,130 @@ +package squirrel + +import ( + "bytes" + "fmt" + + "github.com/lann/builder" +) + +// Common Table Expressions helper +// e.g. +// WITH cte AS ( +// ... +// ), cte_2 AS ( +// ... +// ) +// SELECT ... FROM cte ... cte_2; + +type commonTableExpressionsData struct { + PlaceholderFormat PlaceholderFormat + Recursive bool + CurrentCteName string + Ctes []Sqlizer + Statement Sqlizer +} + +func (d *commonTableExpressionsData) toSql() (sqlStr string, args []interface{}, err error) { + if len(d.Ctes) == 0 { + err = fmt.Errorf("common table expressions statements must have at least one label and subquery") + return + } + + if d.Statement == nil { + err = fmt.Errorf("common table expressions must one of the following final statement: (select, insert, replace, update, delete)") + return + } + + sql := &bytes.Buffer{} + + sql.WriteString("WITH ") + if d.Recursive { + sql.WriteString("RECURSIVE ") + } + + args, err = appendToSql(d.Ctes, sql, ", ", args) + sql.WriteString("\n") + args, err = appendToSql([]Sqlizer{d.Statement}, sql, "", args) + + sqlStr = sql.String() + return +} + +func (d *commonTableExpressionsData) ToSql() (sql string, args []interface{}, err error) { + return d.toSql() +} + +// Builder + +// CommonTableExpressionsBuilder builds CTE (Common Table Expressions) SQL statements. +type CommonTableExpressionsBuilder builder.Builder + +func init() { + builder.Register(CommonTableExpressionsBuilder{}, commonTableExpressionsData{}) +} + +// Format methods + +// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the +// query. +func (b CommonTableExpressionsBuilder) PlaceholderFormat(f PlaceholderFormat) CommonTableExpressionsBuilder { + return builder.Set(b, "PlaceholderFormat", f).(CommonTableExpressionsBuilder) +} + +// SQL methods + +// ToSql builds the query into a SQL string and bound args. +func (b CommonTableExpressionsBuilder) ToSql() (string, []interface{}, error) { + data := builder.GetStruct(b).(commonTableExpressionsData) + return data.ToSql() +} + +// MustSql builds the query into a SQL string and bound args. +// It panics if there are any errors. +func (b CommonTableExpressionsBuilder) MustSql() (string, []interface{}) { + sql, args, err := b.ToSql() + if err != nil { + panic(err) + } + return sql, args +} + +func (b CommonTableExpressionsBuilder) Recursive(recursive bool) CommonTableExpressionsBuilder { + return builder.Set(b, "Recursive", recursive).(CommonTableExpressionsBuilder) +} + +// Cte starts a new cte +func (b CommonTableExpressionsBuilder) Cte(cte string) CommonTableExpressionsBuilder { + return builder.Set(b, "CurrentCteName", cte).(CommonTableExpressionsBuilder) +} + +// As sets the expression for the Cte +func (b CommonTableExpressionsBuilder) As(as SelectBuilder) CommonTableExpressionsBuilder { + data := builder.GetStruct(b).(commonTableExpressionsData) + return builder.Append(b, "Ctes", cteExpr{as, data.CurrentCteName}).(CommonTableExpressionsBuilder) +} + +// Select finalizes the CommonTableExpressionsBuilder with a SELECT +func (b CommonTableExpressionsBuilder) Select(statement SelectBuilder) CommonTableExpressionsBuilder { + return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder) +} + +// Insert finalizes the CommonTableExpressionsBuilder with an INSERT +func (b CommonTableExpressionsBuilder) Insert(statement InsertBuilder) CommonTableExpressionsBuilder { + return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder) +} + +// Replace finalizes the CommonTableExpressionsBuilder with a REPLACE +func (b CommonTableExpressionsBuilder) Replace(statement InsertBuilder) CommonTableExpressionsBuilder { + return b.Insert(statement) +} + +// Update finalizes the CommonTableExpressionsBuilder with an UPDATE +func (b CommonTableExpressionsBuilder) Update(statement UpdateBuilder) CommonTableExpressionsBuilder { + return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder) +} + +// Delete finalizes the CommonTableExpressionsBuilder with a DELETE +func (b CommonTableExpressionsBuilder) Delete(statement DeleteBuilder) CommonTableExpressionsBuilder { + return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder) +} diff --git a/cte_test.go b/cte_test.go new file mode 100644 index 00000000..986bd4e0 --- /dev/null +++ b/cte_test.go @@ -0,0 +1,135 @@ +package squirrel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithAsQuery_OneSubquery(t *testing.T) { + w := With("lab").As( + Select("col").From("tab"). + Where("simple"). + Where("NOT hard"), + ).Select( + Select("col"). + From("lab"), + ) + q, _, err := w.ToSql() + assert.NoError(t, err) + + expectedSql := "WITH lab AS (\n" + + "SELECT col FROM tab WHERE simple AND NOT hard\n" + + ")\n" + + "SELECT col FROM lab" + assert.Equal(t, expectedSql, q) + + w = WithRecursive("lab").As( + Select("col").From("tab"). + Where("simple"). + Where("NOT hard"), + ).Select(Select("col"). + From("lab"), + ) + q, _, err = w.ToSql() + assert.NoError(t, err) + + expectedSql = "WITH RECURSIVE lab AS (\n" + + "SELECT col FROM tab WHERE simple AND NOT hard\n" + + ")\n" + + "SELECT col FROM lab" + assert.Equal(t, expectedSql, q) +} + +func TestWithAsQuery_TwoSubqueries(t *testing.T) { + w := With("lab_1").As( + Select("col_1", "col_common").From("tab_1"). + Where("simple"). + Where("NOT hard"), + ).Cte("lab_2").As( + Select("col_2", "col_common").From("tab_2"), + ).Select(Select("col_1", "col_2", "col_common"). + From("lab_1").Join("lab_2 ON lab_1.col_common = lab_2.col_common"), + ) + q, _, err := w.ToSql() + assert.NoError(t, err) + + expectedSql := "WITH lab_1 AS (\n" + + "SELECT col_1, col_common FROM tab_1 WHERE simple AND NOT hard\n" + + "), lab_2 AS (\n" + + "SELECT col_2, col_common FROM tab_2\n" + + ")\n" + + "SELECT col_1, col_2, col_common FROM lab_1 JOIN lab_2 ON lab_1.col_common = lab_2.col_common" + assert.Equal(t, expectedSql, q) +} + +func TestWithAsQuery_ManySubqueries(t *testing.T) { + w := With("lab_1").As( + Select("col_1", "col_common").From("tab_1"). + Where("simple"). + Where("NOT hard"), + ).Cte("lab_2").As( + Select("col_2", "col_common").From("tab_2"), + ).Cte("lab_3").As( + Select("col_3", "col_common").From("tab_3"), + ).Cte("lab_4").As( + Select("col_4", "col_common").From("tab_4"), + ).Select( + Select("col_1", "col_2", "col_3", "col_4", "col_common"). + From("lab_1").Join("lab_2 ON lab_1.col_common = lab_2.col_common"). + Join("lab_3 ON lab_1.col_common = lab_3.col_common"). + Join("lab_4 ON lab_1.col_common = lab_4.col_common"), + ) + q, _, err := w.ToSql() + assert.NoError(t, err) + + expectedSql := "WITH lab_1 AS (\n" + + "SELECT col_1, col_common FROM tab_1 WHERE simple AND NOT hard\n" + + "), lab_2 AS (\n" + + "SELECT col_2, col_common FROM tab_2\n" + + "), lab_3 AS (\n" + + "SELECT col_3, col_common FROM tab_3\n" + + "), lab_4 AS (\n" + + "SELECT col_4, col_common FROM tab_4\n" + + ")\n" + + "SELECT col_1, col_2, col_3, col_4, col_common FROM lab_1 JOIN lab_2 ON lab_1.col_common = lab_2.col_common JOIN lab_3 ON lab_1.col_common = lab_3.col_common JOIN lab_4 ON lab_1.col_common = lab_4.col_common" + assert.Equal(t, expectedSql, q) +} + +func TestWithAsQuery_Insert(t *testing.T) { + w := With("lab").As( + Select("col").From("tab"). + Where("simple"). + Where("NOT hard"), + ).Insert(Insert("ins_tab").Columns("ins_col").Select(Select("col").From("lab"))) + q, _, err := w.ToSql() + assert.NoError(t, err) + + expectedSql := "WITH lab AS (\n" + + "SELECT col FROM tab WHERE simple AND NOT hard\n" + + ")\n" + + "INSERT INTO ins_tab (ins_col) SELECT col FROM lab" + assert.Equal(t, expectedSql, q) +} + +func TestWithAsQuery_Update(t *testing.T) { + w := With("lab").As( + Select("col", "common_col").From("tab"). + Where("simple"). + Where("NOT hard"), + ).Update( + Update("upd_tab, lab"). + Set("upd_col", Expr("lab.col")). + Where("common_col = lab.common_col"), + ) + + q, _, err := w.ToSql() + assert.NoError(t, err) + + expectedSql := "WITH lab AS (\n" + + "SELECT col, common_col FROM tab WHERE simple AND NOT hard\n" + + ")\n" + + "UPDATE upd_tab, lab SET upd_col = lab.col WHERE common_col = lab.common_col" + + assert.Equal(t, expectedSql, q) +} diff --git a/expr.go b/expr.go index eba1b457..8a1e9691 100644 --- a/expr.go +++ b/expr.go @@ -23,7 +23,8 @@ type expr struct { // Expr builds an expression from a SQL fragment and arguments. // // Ex: -// Expr("FROM_UNIXTIME(?)", t) +// +// Expr("FROM_UNIXTIME(?)", t) func Expr(sql string, args ...interface{}) Sqlizer { return expr{sql: sql, args: args} } @@ -105,8 +106,9 @@ func (ce concatExpr) ToSql() (sql string, args []interface{}, err error) { // ConcatExpr builds an expression by concatenating strings and other expressions. // // Ex: -// name_expr := Expr("CONCAT(?, ' ', ?)", firstName, lastName) -// ConcatExpr("COALESCE(full_name,", name_expr, ")") +// +// name_expr := Expr("CONCAT(?, ' ', ?)", firstName, lastName) +// ConcatExpr("COALESCE(full_name,", name_expr, ")") func ConcatExpr(parts ...interface{}) concatExpr { return concatExpr(parts) } @@ -120,7 +122,8 @@ type aliasExpr struct { // Alias allows to define alias for column in SelectBuilder. Useful when column is // defined as complex expression like IF or CASE // Ex: -// .Column(Alias(caseStmt, "case_column")) +// +// .Column(Alias(caseStmt, "case_column")) func Alias(expr Sqlizer, alias string) aliasExpr { return aliasExpr{expr, alias} } @@ -212,7 +215,8 @@ func (eq Eq) ToSql() (sql string, args []interface{}, err error) { // NotEq is syntactic sugar for use with Where/Having/Set methods. // Ex: -// .Where(NotEq{"id": 1}) == "id <> 1" +// +// .Where(NotEq{"id": 1}) == "id <> 1" type NotEq Eq func (neq NotEq) ToSql() (sql string, args []interface{}, err error) { @@ -221,7 +225,8 @@ func (neq NotEq) ToSql() (sql string, args []interface{}, err error) { // Like is syntactic sugar for use with LIKE conditions. // Ex: -// .Where(Like{"name": "%irrel"}) +// +// .Where(Like{"name": "%irrel"}) type Like map[string]interface{} func (lk Like) toSql(opr string) (sql string, args []interface{}, err error) { @@ -260,7 +265,8 @@ func (lk Like) ToSql() (sql string, args []interface{}, err error) { // NotLike is syntactic sugar for use with LIKE conditions. // Ex: -// .Where(NotLike{"name": "%irrel"}) +// +// .Where(NotLike{"name": "%irrel"}) type NotLike Like func (nlk NotLike) ToSql() (sql string, args []interface{}, err error) { @@ -269,7 +275,8 @@ func (nlk NotLike) ToSql() (sql string, args []interface{}, err error) { // ILike is syntactic sugar for use with ILIKE conditions. // Ex: -// .Where(ILike{"name": "sq%"}) +// +// .Where(ILike{"name": "sq%"}) type ILike Like func (ilk ILike) ToSql() (sql string, args []interface{}, err error) { @@ -278,7 +285,8 @@ func (ilk ILike) ToSql() (sql string, args []interface{}, err error) { // NotILike is syntactic sugar for use with ILIKE conditions. // Ex: -// .Where(NotILike{"name": "sq%"}) +// +// .Where(NotILike{"name": "sq%"}) type NotILike Like func (nilk NotILike) ToSql() (sql string, args []interface{}, err error) { @@ -287,7 +295,8 @@ func (nilk NotILike) ToSql() (sql string, args []interface{}, err error) { // Lt is syntactic sugar for use with Where/Having/Set methods. // Ex: -// .Where(Lt{"id": 1}) +// +// .Where(Lt{"id": 1}) type Lt map[string]interface{} func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err error) { @@ -339,7 +348,8 @@ func (lt Lt) ToSql() (sql string, args []interface{}, err error) { // LtOrEq is syntactic sugar for use with Where/Having/Set methods. // Ex: -// .Where(LtOrEq{"id": 1}) == "id <= 1" +// +// .Where(LtOrEq{"id": 1}) == "id <= 1" type LtOrEq Lt func (ltOrEq LtOrEq) ToSql() (sql string, args []interface{}, err error) { @@ -348,7 +358,8 @@ func (ltOrEq LtOrEq) ToSql() (sql string, args []interface{}, err error) { // Gt is syntactic sugar for use with Where/Having/Set methods. // Ex: -// .Where(Gt{"id": 1}) == "id > 1" +// +// .Where(Gt{"id": 1}) == "id > 1" type Gt Lt func (gt Gt) ToSql() (sql string, args []interface{}, err error) { @@ -357,7 +368,8 @@ func (gt Gt) ToSql() (sql string, args []interface{}, err error) { // GtOrEq is syntactic sugar for use with Where/Having/Set methods. // Ex: -// .Where(GtOrEq{"id": 1}) == "id >= 1" +// +// .Where(GtOrEq{"id": 1}) == "id >= 1" type GtOrEq Lt func (gtOrEq GtOrEq) ToSql() (sql string, args []interface{}, err error) { @@ -417,3 +429,20 @@ func isListType(val interface{}) bool { valVal := reflect.ValueOf(val) return valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice } + +type cteExpr struct { + expr Sqlizer + cte string +} + +func Cte(expr Sqlizer, cte string) cteExpr { + return cteExpr{expr, cte} +} + +func (e cteExpr) ToSql() (sql string, args []interface{}, err error) { + sql, args, err = e.expr.ToSql() + if err == nil { + sql = fmt.Sprintf("%s AS (\n%s\n)", e.cte, sql) + } + return +} diff --git a/statement.go b/statement.go index 9420c67f..400d7555 100644 --- a/statement.go +++ b/statement.go @@ -31,6 +31,11 @@ func (b StatementBuilderType) Delete(from string) DeleteBuilder { return DeleteBuilder(b).From(from) } +// With returns a CommonTableExpressionsBuilder for this StatementBuilderType +func (b StatementBuilderType) With(cte string) CommonTableExpressionsBuilder { + return CommonTableExpressionsBuilder(b).Cte(cte) +} + // PlaceholderFormat sets the PlaceholderFormat field for any child builders. func (b StatementBuilderType) PlaceholderFormat(f PlaceholderFormat) StatementBuilderType { return builder.Set(b, "PlaceholderFormat", f).(StatementBuilderType) @@ -87,6 +92,20 @@ func Delete(from string) DeleteBuilder { return StatementBuilder.Delete(from) } +// With returns a new CommonTableExpressionsBuilder with the given first cte name +// +// See CommonTableExpressionsBuilder.Cte +func With(cte string) CommonTableExpressionsBuilder { + return StatementBuilder.With(cte) +} + +// WithRecursive returns a new CommonTableExpressionsBuilder with the RECURSIVE option and the given first cte name +// +// See CommonTableExpressionsBuilder.Cte, CommonTableExpressionsBuilder.Recursive +func WithRecursive(cte string) CommonTableExpressionsBuilder { + return StatementBuilder.With(cte).Recursive(true) +} + // Case returns a new CaseBuilder // "what" represents case value func Case(what ...interface{}) CaseBuilder {