Skip to content

Commit

Permalink
chore(embedded/sql): Implement CASE statement
Browse files Browse the repository at this point in the history
Signed-off-by: Stefano Scafiti <[email protected]>
  • Loading branch information
ostafen committed Nov 27, 2024
1 parent e77545f commit 51c0742
Show file tree
Hide file tree
Showing 8 changed files with 945 additions and 298 deletions.
158 changes: 158 additions & 0 deletions embedded/sql/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2885,6 +2885,164 @@ func TestQuery(t *testing.T) {
}
})

t.Run("query with case when then", func(t *testing.T) {
_, _, err := engine.Exec(
context.Background(),
nil,
`CREATE TABLE employees (
employee_id INTEGER AUTO_INCREMENT,
first_name VARCHAR[50],
last_name VARCHAR[50],
department VARCHAR[50],
salary INTEGER,
hire_date TIMESTAMP,
job_title VARCHAR[50],
PRIMARY KEY employee_id
);`,
nil,
)
require.NoError(t, err)

n := 100
for i := 0; i < n; i++ {
_, _, err := engine.Exec(
context.Background(),
nil,
`INSERT INTO employees(first_name, last_name, department, salary, job_title)
VALUES (@first_name, @last_name, @department, @salary, @job_title)
`,
map[string]interface{}{
"first_name": fmt.Sprintf("name%d", i),
"last_name": fmt.Sprintf("surname%d", i),
"department": []string{"sales", "manager", "engineering"}[rand.Intn(3)],
"salary": []int64{20, 40, 50, 80, 100}[rand.Intn(5)] * 1000,
"job_title": []string{"manager", "senior engineer", "executive"}[rand.Intn(3)],
},
)
require.NoError(t, err)
}

_, err = engine.queryAll(
context.Background(),
nil,
"SELECT CASE WHEN salary THEN 0 END FROM employees",
nil,
)
require.ErrorIs(t, err, ErrInvalidTypes)

rows, err := engine.queryAll(
context.Background(),
nil,
`SELECT
employee_id,
first_name,
last_name,
salary,
CASE
WHEN salary < 50000 THEN @low
WHEN salary >= 50000 AND salary <= 100000 THEN @medium
ELSE @high
END AS salary_category
FROM employees;`,
map[string]interface{}{
"low": "Low",
"medium": "Medium",
"high": "High",
},
)
require.NoError(t, err)
require.Len(t, rows, n)

for _, row := range rows {
salary := row.ValuesByPosition[3].RawValue().(int64)
category, _ := row.ValuesByPosition[4].RawValue().(string)

expectedCategory := "High"
if salary < 50000 {
expectedCategory = "Low"
} else if salary >= 50000 && salary <= 100000 {
expectedCategory = "Medium"
}
require.Equal(t, expectedCategory, category)
}

rows, err = engine.queryAll(
context.Background(),
nil,
`SELECT
department,
job_title,
CASE
WHEN department = 'sales' THEN
CASE
WHEN job_title = 'manager' THEN '20% Bonus'
ELSE '10% Bonus'
END
WHEN department = 'engineering' THEN
CASE
WHEN job_title = 'senior engineer' THEN '15% Bonus'
ELSE '5% Bonus'
END
ELSE
CASE
WHEN job_title = 'executive' THEN '12% Bonus'
ELSE 'No Bonus'
END
END AS bonus
FROM employees;`,
nil,
)
require.NoError(t, err)
require.Len(t, rows, n)

for _, row := range rows {
department := row.ValuesByPosition[0].RawValue().(string)
job, _ := row.ValuesByPosition[1].RawValue().(string)
bonus, _ := row.ValuesByPosition[2].RawValue().(string)

var expectedBonus string
switch department {
case "sales":
if job == "manager" {
expectedBonus = "20% Bonus"
} else {
expectedBonus = "10% Bonus"
}
case "engineering":
if job == "senior engineer" {
expectedBonus = "15% Bonus"
} else {
expectedBonus = "5% Bonus"
}
default:
if job == "executive" {
expectedBonus = "12% Bonus"
} else {
expectedBonus = "No Bonus"
}
}
require.Equal(t, expectedBonus, bonus)
}

rows, err = engine.queryAll(
context.Background(),
nil,
`SELECT
CASE
WHEN department = 'sales' THEN 'Sales Team'
END AS department
FROM employees
WHERE department != 'sales'
LIMIT 1
;`,
nil,
)
require.NoError(t, err)
require.Len(t, rows, 1)
require.Nil(t, rows[0].ValuesByPosition[0].RawValue())
})

t.Run("invalid queries", func(t *testing.T) {
r, err = engine.Query(context.Background(), nil, "INVALID QUERY", nil)
require.ErrorIs(t, err, ErrParsingError)
Expand Down
5 changes: 5 additions & 0 deletions embedded/sql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ var reservedWords = map[string]int{
"PRIVILEGES": PRIVILEGES,
"CHECK": CHECK,
"CONSTRAINT": CONSTRAINT,
"CASE": CASE,
"WHEN": WHEN,
"THEN": THEN,
"ELSE": ELSE,
"END": END,
}

var joinTypes = map[string]JoinType{
Expand Down
97 changes: 96 additions & 1 deletion embedded/sql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,7 @@ func TestAggFnStmt(t *testing.T) {
}
}

func TestExpressions(t *testing.T) {
func TestParseExp(t *testing.T) {
testCases := []struct {
input string
expectedOutput []SQLStmt
Expand Down Expand Up @@ -1674,6 +1674,98 @@ func TestExpressions(t *testing.T) {
}},
expectedError: nil,
},
{
input: "SELECT CASE WHEN is_deleted OR is_expired THEN 1 END AS is_deleted_or_expired FROM my_table",
expectedOutput: []SQLStmt{
&SelectStmt{
ds: &tableRef{table: "my_table"},
targets: []TargetEntry{
{
Exp: &CaseWhenExp{
whenThen: []whenThenClause{
{
when: &BinBoolExp{
op: OR,
left: &ColSelector{col: "is_deleted"},
right: &ColSelector{col: "is_expired"},
},
then: &Integer{1},
},
},
},
As: "is_deleted_or_expired",
},
},
},
},
},
{
input: "SELECT CASE WHEN is_active THEN 1 ELSE 2 END FROM my_table",
expectedOutput: []SQLStmt{
&SelectStmt{
ds: &tableRef{table: "my_table"},
targets: []TargetEntry{
{
Exp: &CaseWhenExp{
whenThen: []whenThenClause{
{
when: &ColSelector{col: "is_active"},
then: &Integer{1},
},
},
elseExp: &Integer{2},
},
},
},
},
},
},
{
input: `
SELECT product_name,
CASE
WHEN stock < 10 THEN 'Low stock'
WHEN stock >= 10 AND stock <= 50 THEN 'Medium stock'
WHEN stock > 50 THEN 'High stock'
ELSE 'Out of stock'
END AS stock_status
FROM products
`,
expectedOutput: []SQLStmt{
&SelectStmt{
ds: &tableRef{table: "products"},
targets: []TargetEntry{
{
Exp: &ColSelector{col: "product_name"},
},
{
Exp: &CaseWhenExp{
whenThen: []whenThenClause{
{
when: &CmpBoolExp{op: LT, left: &ColSelector{col: "stock"}, right: &Integer{10}},
then: &Varchar{"Low stock"},
},
{
when: &BinBoolExp{
op: AND,
left: &CmpBoolExp{op: GE, left: &ColSelector{col: "stock"}, right: &Integer{10}},
right: &CmpBoolExp{op: LE, left: &ColSelector{col: "stock"}, right: &Integer{50}},
},
then: &Varchar{"Medium stock"},
},
{
when: &CmpBoolExp{op: GT, left: &ColSelector{col: "stock"}, right: &Integer{50}},
then: &Varchar{"High stock"},
},
},
elseExp: &Varchar{"Out of stock"},
},
As: "stock_status",
},
},
},
},
},
}

for i, tc := range testCases {
Expand Down Expand Up @@ -1897,6 +1989,9 @@ func TestExprString(t *testing.T) {
"((col1 AND (col2 < 10)) OR (@param = 3 AND (col4 = TRUE))) AND NOT (col5 = 'value' OR (2 + 2 != 4))",
"CAST (func_call(1, 'two', 2.5) AS TIMESTAMP)",
"col IN (TRUE, 1, 'test', 1.5)",
"CASE WHEN in_stock THEN 'In Stock' END",
"CASE WHEN 1 > 0 THEN 1 ELSE 0 END",
"CASE WHEN is_active THEN 'active' WHEN is_expired THEN 'expired' ELSE 'active' END",
}

for i, e := range exps {
Expand Down
7 changes: 6 additions & 1 deletion embedded/sql/proj_row_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,12 @@ func (pr *projectedRowReader) Read(ctx context.Context) (*Row, error) {
}

for i, t := range pr.targets {
v, err := t.Exp.reduce(pr.Tx(), row, pr.rowReader.TableAlias())
e, err := t.Exp.substitute(pr.Parameters())
if err != nil {
return nil, fmt.Errorf("%w: when evaluating WHERE clause", err)
}

v, err := e.reduce(pr.Tx(), row, pr.rowReader.TableAlias())
if err != nil {
return nil, err
}
Expand Down
53 changes: 50 additions & 3 deletions embedded/sql/sql_grammar.y
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,14 @@ func setResult(l yyLexer, stmts []SQLStmt) {
permission Permission
sqlPrivilege SQLPrivilege
sqlPrivileges []SQLPrivilege
whenThenClauses []whenThenClause
}

%token CREATE DROP USE DATABASE USER WITH PASSWORD READ READWRITE ADMIN SNAPSHOT HISTORY SINCE AFTER BEFORE UNTIL TX OF TIMESTAMP
%token TABLE UNIQUE INDEX ON ALTER ADD RENAME TO COLUMN CONSTRAINT PRIMARY KEY CHECK GRANT REVOKE GRANTS FOR PRIVILEGES
%token BEGIN TRANSACTION COMMIT ROLLBACK
%token INSERT UPSERT INTO VALUES DELETE UPDATE SET CONFLICT DO NOTHING RETURNING
%token SELECT DISTINCT FROM JOIN HAVING WHERE GROUP BY LIMIT OFFSET ORDER ASC DESC AS UNION ALL
%token SELECT DISTINCT FROM JOIN HAVING WHERE GROUP BY LIMIT OFFSET ORDER ASC DESC AS UNION ALL CASE WHEN THEN ELSE END
%token NOT LIKE IF EXISTS IN IS
%token AUTO_INCREMENT NULL CAST SCAST
%token SHOW DATABASES TABLES USERS
Expand Down Expand Up @@ -135,10 +136,10 @@ func setResult(l yyLexer, stmts []SQLStmt) {
%type <join> join
%type <joinType> opt_join_type
%type <checks> opt_checks
%type <exp> exp opt_where opt_having boundexp
%type <exp> exp opt_where opt_having boundexp opt_else when_then_else
%type <binExp> binExp
%type <cols> opt_groupby
%type <exp> opt_limit opt_offset
%type <exp> opt_limit opt_offset case_when_exp
%type <targets> opt_targets targets
%type <integer> opt_max_len
%type <id> opt_as
Expand All @@ -152,6 +153,7 @@ func setResult(l yyLexer, stmts []SQLStmt) {
%type <permission> permission
%type <sqlPrivilege> sqlPrivilege
%type <sqlPrivileges> sqlPrivileges
%type <whenThenClauses> when_then_clauses

%start sql

Expand Down Expand Up @@ -1095,6 +1097,51 @@ exp:
{
$$ = &InListExp{val: $1, notIn: $2, values: $5}
}
|
case_when_exp
{
$$ = $1
}

case_when_exp:
CASE when_then_else END
{
$$ = $2
}
;

when_then_else:
when_then_clauses opt_else
{
$$ = &CaseWhenExp{
whenThen: $1,
elseExp: $2,
}
}
;

when_then_clauses:
WHEN exp THEN exp
{
$$ = []whenThenClause{{when: $2, then: $4}}
}
|
when_then_clauses WHEN exp THEN exp
{
$$ = append($1, whenThenClause{when: $3, then: $5})
}
;

opt_else:
{
$$ = nil
}
|
ELSE exp
{
$$ = $2
}
;

boundexp:
selector
Expand Down
Loading

0 comments on commit 51c0742

Please sign in to comment.