Commit 74f1176

yeshengm authored and gatorsmile committed
[SPARK-27815][SQL] Predicate pushdown in one pass for cascading joins
## What changes were proposed in this pull request?

This PR makes the predicate pushdown logic in the Catalyst optimizer more efficient by unifying the two existing rules `PushDownPredicate` and `PushPredicateThroughJoin`. Previously, pushing a predicate down through a query such as `Filter(Join(Join(Join)))` required n optimizer passes; this patch reduces that to a single pass.

For this to actually work, a few rules have to be unified: `CombineFilters`, `PushDownPredicate`, and `PushPredicateThroughJoin`. Otherwise, cases such as `Filter(Join(Filter(Join)))` would still need several passes to fully push down predicates. The unification is done by composing several partial functions, which keeps the code change minimal and allows existing UTs to be reused.

Results show that this optimization improves Catalyst optimization time by 16.5%. Queries with more joins benefit even more; for TPC-DS q64, the boost is 49.2%.

## How was this patch tested?

Existing UTs, plus a new UT for the new rule.

Closes apache#24956 from yeshengm/fixed-point-opt.

Authored-by: Yesheng Ma <[email protected]>
Signed-off-by: gatorsmile <[email protected]>
1 parent 70b1a10 · commit 74f1176

17 files changed: +235 −34 lines
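
The core mechanism in the patch is plain `PartialFunction` composition: each rule's body becomes a reusable `applyLocally` partial function, and the new `PushDownPredicates` rule chains them with `orElse` so that a single top-down `transform` can apply whichever local rule matches at each node. Below is a minimal, self-contained Scala sketch of that idea; the plan types and the two local rules are toy stand-ins rather than Catalyst's classes, and only the `orElse` composition and the top-down recursion mirror the actual patch.

  // Toy sketch of one-pass predicate pushdown via composed partial functions.
  object OnePassPushdownSketch {
    sealed trait Plan
    case class Relation(name: String) extends Plan
    case class Filter(cond: String, child: Plan) extends Plan
    case class Join(left: Plan, right: Plan) extends Plan

    // Toy stand-ins for CombineFilters.applyLocally and
    // PushPredicateThroughJoin.applyLocally.
    val combineFilters: PartialFunction[Plan, Plan] = {
      case Filter(c1, Filter(c2, child)) => Filter(s"($c2) AND ($c1)", child)
    }
    val pushThroughJoin: PartialFunction[Plan, Plan] = {
      // Toy rule: always push to the left child; the real rule splits the
      // predicate by which side's attributes it references.
      case Filter(c, Join(l, r)) => Join(Filter(c, l), r)
    }

    // The unified rule: at each node, try the local rules in order.
    val unified: PartialFunction[Plan, Plan] = combineFilters.orElse(pushThroughJoin)

    // Top-down transform, like Catalyst's `plan transform rule`: rewrite the
    // node if the rule matches, then recurse into the (possibly new) children,
    // so a pushed-down Filter is visited again in the same walk.
    def transformDown(plan: Plan): Plan = {
      val rewritten = unified.applyOrElse(plan, identity[Plan])
      rewritten match {
        case Filter(c, child) => Filter(c, transformDown(child))
        case Join(l, r)       => Join(transformDown(l), transformDown(r))
        case leaf             => leaf
      }
    }

    def main(args: Array[String]): Unit = {
      val plan = Filter("a = 1", Join(Join(Relation("t1"), Relation("t2")), Relation("t3")))
      // One walk pushes the filter through both joins:
      // Join(Join(Filter(a = 1,Relation(t1)),Relation(t2)),Relation(t3))
      println(transformDown(plan))
    }
  }

Because the transform recurses into the children of each rewritten node, a filter pushed below a join is visited again in the same walk and keeps descending, which is why a `Filter(Join(Join(Join)))` plan resolves in one pass.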

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

+24 −6
@@ -63,8 +63,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       PushProjectionThroughUnion,
       ReorderJoin,
       EliminateOuterJoin,
-      PushPredicateThroughJoin,
-      PushDownPredicate,
+      PushDownPredicates,
       PushDownLeftSemiAntiJoin,
       PushLeftSemiLeftAntiThroughJoin,
       LimitPushDown,
@@ -911,7 +910,9 @@ object CombineUnions extends Rule[LogicalPlan] {
  * one conjunctive predicate.
  */
 object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
+
+  val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
     // The query execution/optimization does not guarantee the expressions are evaluated in order.
     // We only can combine them if and only if both are deterministic.
     case Filter(fc, nf @ Filter(nc, grandChild)) if fc.deterministic && nc.deterministic =>
@@ -996,15 +997,30 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
+/**
+ * The unified version for predicate pushdown of normal operators and joins.
+ * This rule improves performance of predicate pushdown for cascading joins such as:
+ * Filter-Join-Join-Join. Most predicates can be pushed down in a single pass.
+ */
+object PushDownPredicates extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    CombineFilters.applyLocally
+      .orElse(PushPredicateThroughNonJoin.applyLocally)
+      .orElse(PushPredicateThroughJoin.applyLocally)
+  }
+}
+
 /**
  * Pushes [[Filter]] operators through many operators iff:
  * 1) the operator is deterministic
  * 2) the predicate is deterministic and the operator will not change any of rows.
  *
  * This heuristic is valid assuming the expression evaluation cost is minimal.
  */
-object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
+
+  val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
     // SPARK-13473: We can't push the predicate down when the underlying projection output non-
     // deterministic field(s). Non-deterministic expressions are essentially stateful. This
     // implies that, for a given input row, the output are determined by the expression's initial
@@ -1221,7 +1237,9 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
     (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
+
+  val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
     // push the where condition down into join filter
     case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition, hint)) =>
       val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
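
What makes the chain single-pass is that each local rewrite leaves the moved `Filter` at a node the top-down `transform` has not visited yet, so the next rule in the `orElse` chain can fire one level further down within the same walk. A hand-written trace, following the "push down through join and project" test in the new suite below (illustrative shorthand, not actual optimizer output):

  Filter(x.a < 100, Join(Project(a, b, Filter(a > 0, x)), Project(e, Filter(d < 100, y))))
    -- PushPredicateThroughJoin.applyLocally: the predicate references only the left side
  Join(Filter(x.a < 100, Project(a, b, Filter(a > 0, x))), Project(e, Filter(d < 100, y)))
    -- same walk, one node down: PushPredicateThroughNonJoin.applyLocally pushes through the Project
  Join(Project(a, b, Filter(x.a < 100, Filter(a > 0, x))), Project(e, Filter(d < 100, y)))
    -- same walk, one node down: CombineFilters.applyLocally fuses the stacked filters
  Join(Project(a, b, Filter(a > 0 AND a < 100, x)), Project(e, Filter(d < 100, y)))

Previously these three steps belonged to three separate rules, so reaching the final plan cost extra rule applications across fixed-point iterations of the batch.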

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala

+5 −5
@@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
- * This rule is a variant of [[PushDownPredicate]] which can handle
+ * This rule is a variant of [[PushPredicateThroughNonJoin]] which can handle
  * pushing down Left semi and Left Anti joins below the following operators.
  *  1) Project
  *  2) Window
  *  3) Union
  *  4) Aggregate
- *  5) Other permissible unary operators. please see [[PushDownPredicate.canPushThrough]].
+ *  5) Other permissible unary operators. please see [[PushPredicateThroughNonJoin.canPushThrough]].
  */
 object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -42,7 +42,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
         // No join condition, just push down the Join below Project
         p.copy(child = Join(gChild, rightOp, joinType, joinCond, hint))
       } else {
-        val aliasMap = PushDownPredicate.getAliasMap(p)
+        val aliasMap = PushPredicateThroughNonJoin.getAliasMap(p)
         val newJoinCond = if (aliasMap.nonEmpty) {
           Option(replaceAlias(joinCond.get, aliasMap))
         } else {
@@ -55,7 +55,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
     case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
         if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
           !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
-      val aliasMap = PushDownPredicate.getAliasMap(agg)
+      val aliasMap = PushPredicateThroughNonJoin.getAliasMap(agg)
       val canPushDownPredicate = (predicate: Expression) => {
         val replaced = replaceAlias(predicate, aliasMap)
         predicate.references.nonEmpty &&
@@ -94,7 +94,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
 
     // LeftSemi/LeftAnti over UnaryNode
     case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(_), _, _)
-        if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
+        if PushPredicateThroughNonJoin.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
       val validAttrs = u.child.outputSet ++ rightOp.outputSet
       pushDownJoin(join, _.references.subsetOf(validAttrs), _.reduce(And))
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala

+1 −1
@@ -32,7 +32,7 @@ class ColumnPruningSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Column pruning", FixedPoint(100),
-      PushDownPredicate,
+      PushPredicateThroughNonJoin,
       ColumnPruning,
       RemoveNoopOperators,
       CollapseProject) :: Nil

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownOnePassSuite.scala

+183 −0 (new file)

@@ -0,0 +1,183 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

/**
 * This test suite ensures that [[PushDownPredicates]] actually does predicate pushdown in
 * an efficient manner. This is enforced by asserting that a single pass of predicate pushdown
 * pushes all predicates down as far as possible.
 */
class FilterPushdownOnePassSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", Once,
        EliminateSubqueryAliases) ::
      // this batch must reach expected state in one pass
      Batch("Filter Pushdown One Pass", Once,
        ReorderJoin,
        PushDownPredicates
      ) :: Nil
  }

  val testRelation1 = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'd.int, 'e.int)

  test("really simple predicate push down") {
    val x = testRelation1.subquery('x)
    val y = testRelation2.subquery('y)

    val originalQuery = x.join(y).where("x.a".attr === 1)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = x.where("x.a".attr === 1).join(y).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("push down conjunctive predicates") {
    val x = testRelation1.subquery('x)
    val y = testRelation2.subquery('y)

    val originalQuery = x.join(y).where("x.a".attr === 1 && "y.d".attr < 1)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = x.where("x.a".attr === 1).join(y.where("y.d".attr < 1)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("push down predicates for simple joins") {
    val x = testRelation1.subquery('x)
    val y = testRelation2.subquery('y)

    val originalQuery =
      x.where("x.c".attr < 0)
        .join(y.where("y.d".attr > 1))
        .where("x.a".attr === 1 && "y.d".attr < 2)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      x.where("x.c".attr < 0 && "x.a".attr === 1)
        .join(y.where("y.d".attr > 1 && "y.d".attr < 2)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("push down top-level filters for cascading joins") {
    val x = testRelation1.subquery('x)
    val y = testRelation2.subquery('y)

    val originalQuery =
      y.join(x).join(x).join(x).join(x).join(x).where("y.d".attr === 0)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = y.where("y.d".attr === 0).join(x).join(x).join(x).join(x).join(x).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("push down predicates for tree-like joins") {
    val x = testRelation1.subquery('x)
    val y1 = testRelation2.subquery('y1)
    val y2 = testRelation2.subquery('y2)

    val originalQuery =
      y1.join(x).join(x)
        .join(y2.join(x).join(x))
        .where("y1.d".attr === 0 && "y2.d".attr === 3)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      y1.where("y1.d".attr === 0).join(x).join(x)
        .join(y2.where("y2.d".attr === 3).join(x).join(x)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("push down through join and project") {
    val x = testRelation1.subquery('x)
    val y = testRelation2.subquery('y)

    val originalQuery =
      x.where('a > 0).select('a, 'b)
        .join(y.where('d < 100).select('e))
        .where("x.a".attr < 100)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      x.where('a > 0 && 'a < 100).select('a, 'b)
        .join(y.where('d < 100).select('e)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("push down through deep projects") {
    val x = testRelation1.subquery('x)

    val originalQuery =
      x.select(('a + 1) as 'a1, 'b)
        .select(('a1 + 1) as 'a2, 'b)
        .select(('a2 + 1) as 'a3, 'b)
        .select(('a3 + 1) as 'a4, 'b)
        .select('b)
        .where('b > 0)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      x.where('b > 0)
        .select(('a + 1) as 'a1, 'b)
        .select(('a1 + 1) as 'a2, 'b)
        .select(('a2 + 1) as 'a3, 'b)
        .select(('a3 + 1) as 'a4, 'b)
        .select('b).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("push down through aggregate and join") {
    val x = testRelation1.subquery('x)
    val y = testRelation2.subquery('y)

    val left = x
      .where('c > 0)
      .groupBy('a)('a, count('b))
      .subquery('left)
    val right = y
      .where('d < 0)
      .groupBy('a)('a, count('d))
      .subquery('right)
    val originalQuery = left
      .join(right).where("left.a".attr < 100 && "right.a".attr < 100)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      x.where('c > 0 && 'a < 100).groupBy('a)('a, count('b))
        .join(y.where('d < 0 && 'a < 100).groupBy('a)('a, count('d)))
        .analyze

    comparePlans(optimized, correctAnswer)
  }
}
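
For contrast, here is a sketch (illustrative only, not part of this commit) of the executor the pre-patch rules would require. Since `CombineFilters`, `PushDownPredicate` (the old name of `PushPredicateThroughNonJoin`), and `PushPredicateThroughJoin` each make only partial progress per application, the batch must run to a fixed point rather than `Once`, exactly like the `FilterPushdownSuite` batch below:

  object OptimizeOldStyle extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", Once,
        EliminateSubqueryAliases) ::
      // With the pre-patch rules one iteration is generally not enough, so the
      // batch repeats until the plan stops changing (or 10 iterations pass).
      Batch("Filter Pushdown Fixed Point", FixedPoint(10),
        CombineFilters,
        PushDownPredicate, // pre-patch name of PushPredicateThroughNonJoin
        PushPredicateThroughJoin) :: Nil
  }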

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

+1 −1
@@ -35,7 +35,7 @@ class FilterPushdownSuite extends PlanTest {
       EliminateSubqueryAliases) ::
     Batch("Filter Pushdown", FixedPoint(10),
       CombineFilters,
-      PushDownPredicate,
+      PushPredicateThroughNonJoin,
       BooleanSimplification,
       PushPredicateThroughJoin,
       CollapseProject) :: Nil

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

+1 −1
@@ -31,7 +31,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
   val batches =
     Batch("InferAndPushDownFilters", FixedPoint(100),
       PushPredicateThroughJoin,
-      PushDownPredicate,
+      PushPredicateThroughNonJoin,
       InferFiltersFromConstraints,
       CombineFilters,
       SimplifyBinaryComparison,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala

+1 −1
@@ -34,7 +34,7 @@ class JoinOptimizationSuite extends PlanTest {
       EliminateSubqueryAliases) ::
     Batch("Filter Pushdown", FixedPoint(100),
       CombineFilters,
-      PushDownPredicate,
+      PushPredicateThroughNonJoin,
       BooleanSimplification,
       ReorderJoin,
       PushPredicateThroughJoin,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala

+1 −1
@@ -35,7 +35,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
       EliminateResolvedHint) ::
     Batch("Operator Optimizations", FixedPoint(100),
       CombineFilters,
-      PushDownPredicate,
+      PushPredicateThroughNonJoin,
       ReorderJoin,
       PushPredicateThroughJoin,
       ColumnPruning,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala

+1 −1
@@ -35,7 +35,7 @@ class LeftSemiPushdownSuite extends PlanTest {
       EliminateSubqueryAliases) ::
     Batch("Filter Pushdown", FixedPoint(10),
       CombineFilters,
-      PushDownPredicate,
+      PushPredicateThroughNonJoin,
       PushDownLeftSemiAntiJoin,
       PushLeftSemiLeftAntiThroughJoin,
       BooleanSimplification,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala

+6 −6
@@ -34,7 +34,7 @@ class OptimizerLoggingSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("Optimizer Batch", FixedPoint(100),
-        PushDownPredicate, ColumnPruning, CollapseProject) ::
+        PushPredicateThroughNonJoin, ColumnPruning, CollapseProject) ::
       Batch("Batch Has No Effect", Once,
         ColumnPruning) :: Nil
   }
@@ -99,7 +99,7 @@ class OptimizerLoggingSuite extends PlanTest {
       verifyLog(
         level._2,
         Seq(
-          PushDownPredicate.ruleName,
+          PushPredicateThroughNonJoin.ruleName,
           ColumnPruning.ruleName,
           CollapseProject.ruleName))
     }
@@ -123,15 +123,15 @@ class OptimizerLoggingSuite extends PlanTest {
 
   test("test log rules") {
     val rulesSeq = Seq(
-      Seq(PushDownPredicate.ruleName,
+      Seq(PushPredicateThroughNonJoin.ruleName,
         ColumnPruning.ruleName,
         CollapseProject.ruleName).reduce(_ + "," + _) ->
-        Seq(PushDownPredicate.ruleName,
+        Seq(PushPredicateThroughNonJoin.ruleName,
           ColumnPruning.ruleName,
           CollapseProject.ruleName),
-      Seq(PushDownPredicate.ruleName,
+      Seq(PushPredicateThroughNonJoin.ruleName,
         ColumnPruning.ruleName).reduce(_ + "," + _) ->
-        Seq(PushDownPredicate.ruleName,
+        Seq(PushPredicateThroughNonJoin.ruleName,
           ColumnPruning.ruleName),
       CollapseProject.ruleName ->
         Seq(CollapseProject.ruleName),
