Skip to content

Commit

Permalink
chore(embedded/sql): add support for LEFT JOIN
Browse files Browse the repository at this point in the history
Signed-off-by: Stefano Scafiti <[email protected]>
  • Loading branch information
ostafen committed Nov 29, 2024
1 parent ae9af09 commit a2057db
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 45 deletions.
61 changes: 53 additions & 8 deletions embedded/sql/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5986,7 +5986,7 @@ func TestNestedJoins(t *testing.T) {
require.NoError(t, err)
}

func TestLeftRightJoins(t *testing.T) {
func TestLeftJoins(t *testing.T) {
e := setupCommonTest(t)

_, _, err := e.Exec(
Expand Down Expand Up @@ -6055,10 +6055,28 @@ func TestLeftRightJoins(t *testing.T) {
)
require.NoError(t, err)

rows, err := e.queryAll(
context.Background(),
nil,
`SELECT
assertQueryShouldProduceResults(
t,
e,
`SELECT c.customer_id, c.customer_name, c.email, o.order_id, o.order_date
FROM customers c LEFT JOIN orders o ON c.customer_id = o.customer_id
ORDER BY c.customer_id, o.order_date;`,
`
SELECT *
FROM (
VALUES
(1, 'Alice Johnson', '[email protected]', 101, '2024-11-01'::TIMESTAMP),
(1, 'Alice Johnson', '[email protected]', 103, '2024-11-03'::TIMESTAMP),
(2, 'Bob Smith', '[email protected]', 102, '2024-11-02'::TIMESTAMP),
(3, 'Charlie Brown', '[email protected]', NULL, NULL)
)`,
)

assertQueryShouldProduceResults(
t,
e,
`
SELECT
c.customer_name,
c.email,
o.order_id,
Expand All @@ -6073,10 +6091,16 @@ func TestLeftRightJoins(t *testing.T) {
LEFT JOIN orders o ON oi.order_id = o.order_id
LEFT JOIN customers c ON o.customer_id = c.customer_id
ORDER BY o.order_date, c.customer_name;`,
nil,
`
SELECT *
FROM (
VALUES
('Alice Johnson', '[email protected]', 101, '2024-11-01'::TIMESTAMP, 'Laptop', 2, 1200.00, 2400.00),
('Alice Johnson', '[email protected]', 101, '2024-11-01'::TIMESTAMP, 'Smartphone', 1, 800.00, 800.00),
('Bob Smith', '[email protected]', 102, '2024-11-02'::TIMESTAMP, 'Tablet', 3, 400.00, 1200.00),
('Alice Johnson', '[email protected]', 103, '2024-11-03'::TIMESTAMP, 'Smartphone', 2, 800.00, 1600.00)
)`,
)
require.NoError(t, err)
require.Len(t, rows, 4)
}

func TestReOpening(t *testing.T) {
Expand Down Expand Up @@ -9527,3 +9551,24 @@ func TestFunctions(t *testing.T) {
require.Equal(t, "OBJECT", rows[0].ValuesByPosition[0].RawValue().(string))
})
}

func assertQueryShouldProduceResults(t *testing.T, e *Engine, query, resultQuery string) {
queryReader, err := e.Query(context.Background(), nil, query, nil)
require.NoError(t, err)
defer queryReader.Close()

resultReader, err := e.Query(context.Background(), nil, resultQuery, nil)
require.NoError(t, err)
defer resultReader.Close()

for {
row, err := queryReader.Read(context.Background())
row1, err1 := resultReader.Read(context.Background())
require.Equal(t, err, err1)

if errors.Is(err, ErrNoMoreRows) {
break
}
require.Equal(t, row1.ValuesByPosition, row.ValuesByPosition)
}
}
2 changes: 0 additions & 2 deletions embedded/sql/joint_row_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ func newJointRowReader(rowReader RowReader, joins []*JoinSpec) (*jointRowReader,
return nil, ErrIllegalArguments
}

// Sanity check: Ensure that no RIGHT JOINs are specified,
// as we assume all RIGHT JOINs to be translated into equivalent LEFT JOINs.
for _, jspec := range joins {
if jspec.joinType == RightJoin {
return nil, ErrUnsupportedJoinType
Expand Down
35 changes: 0 additions & 35 deletions embedded/sql/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3409,44 +3409,9 @@ func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[strin
rowReader = newLimitRowReader(rowReader, limit)
}
}

return rowReader, nil
}

// removeRightJoin converts all right joins in the SelectStmt to left joins by swapping the involved data sources.
func (stmt *SelectStmt) removeRightJoin() {
if len(stmt.joins) == 0 {
return
}

newJoins := make([]*JoinSpec, len(stmt.joins)+1)

start := 0
end := len(newJoins) - 1

for i := len(stmt.joins) - 1; i > 0; i-- {
jspec := stmt.joins[len(stmt.joins)-1-i]

if jspec.joinType == RightJoin {
newJoins[start] = jspec
newJoins[start].joinType = LeftJoin
start++
} else {
newJoins[end] = jspec
end--
}
}

newJoins[start] = &JoinSpec{ds: stmt.ds}
if start == 0 {
stmt.joins = newJoins[1:]
return
}

for i := start; i > 0; i-- {
}
}

func (stmt *SelectStmt) rearrangeOrdExps(groupByCols, orderByExps []*OrdExp) ([]*OrdExp, []*OrdExp) {
if len(groupByCols) > 0 && len(orderByExps) > 0 && !ordExpsHaveAggregations(orderByExps) {
if ordExpsHasPrefix(orderByExps, groupByCols, stmt.Alias()) {
Expand Down

0 comments on commit a2057db

Please sign in to comment.