Commit 11d3a74

[SPARK-31705][SQL] Push more possible predicates through Join via CNF conversion
### What changes were proposed in this pull request?

This PR adds a new rule that pushes more predicates through a join by rewriting the join condition to CNF (conjunctive normal form). The example below walks through the steps of this rule:

1. Prepare tables:

   ```sql
   CREATE TABLE x(a INT);
   CREATE TABLE y(b INT);
   ...
   SELECT * FROM x JOIN y ON ((a < 0 and a > b) or a > 10);
   ```

2. Convert the join condition to CNF (a minimal sketch of this conversion step appears below):

   ```
   (a < 0 or a > 10) and (a > b or a > 10)
   ```

3. Split the conjunctive predicates:

   | Predicates |
   | --- |
   | (a < 0 or a > 10) |
   | (a > b or a > 10) |

4. Push down the predicates that reference only one side:

   | Table | Predicate |
   | --- | --- |
   | x | (a < 0 or a > 10) |

### Why are the changes needed?

Improve query performance. PostgreSQL, [Impala](https://issues.apache.org/jira/browse/IMPALA-9183) and Hive support this feature.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests and benchmarks:

| SQL | Before this PR | After this PR |
| --- | --- | --- |
| TPCDS 5T Q13 | 84s | 21s |
| TPCDS 5T Q85 | 66s | 34s |
| TPCH 1T Q19 | 37s | 32s |

Closes #28733 from gengliangwang/cnf.

Lead-authored-by: Gengliang Wang <[email protected]>
Co-authored-by: Yuming Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
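
At its core the conversion relies on the distributive law: an `And` node concatenates the CNF clause lists of its children, while an `Or` node takes their cross product. Below is a minimal, self-contained Scala sketch of that step on a toy boolean AST; the names (`Expr`, `Leaf`, `toCnf`) are illustrative stand-ins rather than Catalyst's `Expression` classes, and `Not` handling (which the real implementation performs during traversal) is omitted.

```scala
// Toy AST for illustration only; not Catalyst's Expression hierarchy.
sealed trait Expr
case class Leaf(name: String) extends Expr
case class And(l: Expr, r: Expr) extends Expr
case class Or(l: Expr, r: Expr) extends Expr

// Returns CNF as a list of disjunctive clauses that are implicitly And-ed.
def toCnf(e: Expr): Seq[Expr] = e match {
  case And(l, r) => toCnf(l) ++ toCnf(r)                                // concatenate clauses
  case Or(l, r)  => for { x <- toCnf(l); y <- toCnf(r) } yield Or(x, y) // cross product
  case leaf      => Seq(leaf)
}

// (a < 0 and a > b) or a > 10  ==>  (a < 0 or a > 10) and (a > b or a > 10)
val clauses = toCnf(Or(And(Leaf("a < 0"), Leaf("a > b")), Leaf("a > 10")))
// clauses == Seq(Or(Leaf("a < 0"), Leaf("a > 10")), Or(Leaf("a > b"), Leaf("a > 10")))
```

The cross product in the `Or` case is why CNF can grow exponentially, and why the patch introduces `spark.sql.optimizer.maxCNFNodeCount` as a cap.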
1 parent 91cd06b commit 11d3a74

File tree

6 files changed: +468 −4 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 95 additions & 1 deletion
```diff
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.immutable.TreeSet
+import scala.collection.mutable
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
@@ -95,7 +97,7 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePredicate] {
   }
 }
 
-trait PredicateHelper {
+trait PredicateHelper extends Logging {
   protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
     condition match {
       case And(cond1, cond2) =>
@@ -198,6 +200,98 @@ trait PredicateHelper {
     case e: Unevaluable => false
     case e => e.children.forall(canEvaluateWithinJoin)
   }
+
+  /**
+   * Convert an expression into conjunctive normal form.
+   * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
+   * CNF can explode exponentially in the size of the input expression when converting [[Or]]
+   * clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases.
+   *
+   * @param condition to be converted into CNF.
+   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
+   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
+   */
+  def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
+    val postOrderNodes = postOrderTraversal(condition)
+    val resultStack = new mutable.Stack[Seq[Expression]]
+    val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
+    // Bottom up approach to get CNF of sub-expressions
+    while (postOrderNodes.nonEmpty) {
+      val cnf = postOrderNodes.pop() match {
+        case _: And =>
+          val right = resultStack.pop()
+          val left = resultStack.pop()
+          left ++ right
+        case _: Or =>
+          // For each side, there is no need to expand predicates of the same references.
+          // So here we can aggregate predicates of the same qualifier as one single predicate,
+          // for reducing the size of pushed down predicates and corresponding codegen.
+          val right = groupExpressionsByQualifier(resultStack.pop())
+          val left = groupExpressionsByQualifier(resultStack.pop())
+          // Stop the loop whenever the result exceeds the `maxCnfNodeCount`
+          if (left.size * right.size > maxCnfNodeCount) {
+            logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount, the CNF " +
+              "conversion is skipped and Seq.empty is returned. To avoid this, you can " +
+              s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.")
+            return Seq.empty
+          } else {
+            for { x <- left; y <- right } yield Or(x, y)
+          }
+        case other => other :: Nil
+      }
+      resultStack.push(cnf)
+    }
+    if (resultStack.length != 1) {
+      logWarning("The length of CNF conversion result stack is supposed to be 1. There might " +
+        "be something wrong with CNF conversion.")
+      return Seq.empty
+    }
+    resultStack.top
+  }
+
+  private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
+    expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
+  }
+
+  /**
+   * Iterative post order traversal over a binary tree built by And/Or clauses with two stacks.
+   * For example, for a condition `(a And b) Or c`, the post order traversal is
+   * (`a`, `b`, `And`, `c`, `Or`).
+   * Following is the complete algorithm. After step 2, we get the post order traversal in
+   * the second stack.
+   *   1. Push root to first stack.
+   *   2. Loop while first stack is not empty
+   *      2.1 Pop a node from first stack and push it to second stack
+   *      2.2 Push the children of the popped node to first stack
+   *
+   * @param condition to be traversed as binary tree
+   * @return sub-expressions in post order traversal as a stack.
+   *         The first element of result stack is the leftmost node.
+   */
+  private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
+    val stack = new mutable.Stack[Expression]
+    val result = new mutable.Stack[Expression]
+    stack.push(condition)
+    while (stack.nonEmpty) {
+      val node = stack.pop()
+      node match {
+        case Not(a And b) => stack.push(Or(Not(a), Not(b)))
+        case Not(a Or b) => stack.push(And(Not(a), Not(b)))
+        case Not(Not(a)) => stack.push(a)
+        case a And b =>
+          result.push(node)
+          stack.push(a)
+          stack.push(b)
+        case a Or b =>
+          result.push(node)
+          stack.push(a)
+          stack.push(b)
+        case _ =>
+          result.push(node)
+      }
+    }
+    result
+  }
 }
 
 @ExpressionDescription(
```
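
The `postOrderTraversal` helper is the classic two-stack iterative post-order traversal. Here is a standalone sketch of the same idea on a toy tree, with illustrative names (`Tree`, `Op`, `Lit`), that can be checked independently of Spark:

```scala
import scala.collection.mutable

sealed trait Tree
case class Op(name: String, left: Tree, right: Tree) extends Tree
case class Lit(name: String) extends Tree

// Two-stack post-order: each node popped from `stack` is pushed onto `result`
// before its children, so `result` later pops leftmost-leaf-first.
def postOrder(root: Tree): mutable.Stack[Tree] = {
  val stack = new mutable.Stack[Tree]
  val result = new mutable.Stack[Tree]
  stack.push(root)
  while (stack.nonEmpty) {
    stack.pop() match {
      case op @ Op(_, l, r) =>
        result.push(op)
        stack.push(l)
        stack.push(r)
      case lit => result.push(lit)
    }
  }
  result
}

// (a And b) Or c pops as: a, b, And, c, Or -- matching the Javadoc above.
val tree = Op("Or", Op("And", Lit("a"), Lit("b")), Lit("c"))
```

Processing the result stack therefore visits children before their parent, which is exactly what the bottom-up loop in `conjunctiveNormalForm` needs.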

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 7 additions & 2 deletions
```diff
@@ -51,7 +51,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val blacklistedOnceBatches: Set[String] =
     Set(
       "PartitionPruning",
-      "Extract Python UDFs")
+      "Extract Python UDFs",
+      "Push CNF predicate through join")
 
   protected def fixedPoint =
     FixedPoint(
@@ -118,7 +119,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
     Batch("Infer Filters", Once,
       InferFiltersFromConstraints) ::
     Batch("Operator Optimization after Inferring Filters", fixedPoint,
-      rulesWithoutInferFiltersFromConstraints: _*) :: Nil
+      rulesWithoutInferFiltersFromConstraints: _*) ::
+    // Set the strategy to Once to avoid pushing down the same filters repeatedly,
+    // because this rule does not change the join condition.
+    Batch("Push CNF predicate through join", Once,
+      PushCNFPredicateThroughJoin) :: Nil
   }
 
   val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
```
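
For context on the `Once` strategy: a `RuleExecutor` batch either runs a single pass (`Once`) or reapplies its rules until the plan stops changing or an iteration cap is hit (`FixedPoint`). The following is a simplified stand-in sketch of that control flow, not Catalyst's actual `RuleExecutor` types:

```scala
// Simplified stand-ins for illustration; Catalyst's RuleExecutor differs in detail.
sealed trait Strategy { def maxIterations: Int }
case object Once extends Strategy { val maxIterations = 1 }
case class FixedPoint(maxIterations: Int) extends Strategy

case class Batch[P](name: String, strategy: Strategy, rules: Seq[P => P]) {
  def execute(plan: P): P = {
    var current = plan
    var changed = true
    var iteration = 0
    while (changed && iteration < strategy.maxIterations) {
      val next = rules.foldLeft(current)((p, rule) => rule(p))
      changed = next != current // FixedPoint stops on stability; Once stops regardless
      current = next
      iteration += 1
    }
    current
  }
}
```

The batch is also listed in `blacklistedOnceBatches` because the rule is not idempotent: it leaves the join condition intact, so a second pass would derive and push the same filters again.
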
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala

Lines changed: 62 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

/**
 * Try converting the join condition to conjunctive normal form so that more predicates can be
 * pushed down. To avoid expanding the join, the join condition is kept in its original form
 * even when predicate pushdown happens.
 */
object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
      val predicates = conjunctiveNormalForm(joinCondition)
      if (predicates.isEmpty) {
        j
      } else {
        val pushDownCandidates = predicates.filter(_.deterministic)
        lazy val leftFilterConditions =
          pushDownCandidates.filter(_.references.subsetOf(left.outputSet))
        lazy val rightFilterConditions =
          pushDownCandidates.filter(_.references.subsetOf(right.outputSet))

        lazy val newLeft =
          leftFilterConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
        lazy val newRight =
          rightFilterConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)

        joinType match {
          case _: InnerLike | LeftSemi =>
            Join(newLeft, newRight, joinType, Some(joinCondition), hint)
          case RightOuter =>
            Join(newLeft, right, RightOuter, Some(joinCondition), hint)
          case LeftOuter | LeftAnti | ExistenceJoin(_) =>
            Join(left, newRight, joinType, Some(joinCondition), hint)
          case FullOuter => j
          case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
          case UsingJoin(_, _) => sys.error("Untransformed Using join node")
        }
      }
  }
}
```
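
To illustrate the join-type handling, here are hand-sketched plan shapes for the commit-message example; they are illustrative, not captured `EXPLAIN` output:

```scala
// Inner join: CNF clauses that reference only one side become Filters on that side.
//   Join Inner, (((a < 0) AND (a > b)) OR (a > 10))
//   :- Filter ((a < 0) OR (a > 10))   // pushed into x
//   :  +- Relation[x]
//   +- Relation[y]                    // (a > b OR a > 10) references both sides, so it stays
//
// Outer joins only filter the non-preserved side: LeftOuter, LeftAnti and
// ExistenceJoin may filter the right child, RightOuter may filter the left
// child, and FullOuter pushes nothing.
```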

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 15 additions & 0 deletions
```diff
@@ -545,6 +545,19 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val MAX_CNF_NODE_COUNT =
+    buildConf("spark.sql.optimizer.maxCNFNodeCount")
+      .internal()
+      .doc("Specifies the maximum allowable number of conjuncts in the result of CNF " +
+        "conversion. If the conversion exceeds the threshold, an empty sequence is returned. " +
+        "For example, CNF conversion of (a && b) || (c && d) generates " +
+        "four conjuncts (a || c) && (a || d) && (b || c) && (b || d).")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ >= 0,
+        "The maximum number of conjuncts allowed in CNF conversion must be non-negative.")
+      .createWithDefault(128)
+
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -2874,6 +2887,8 @@ class SQLConf extends Serializable with Logging {
 
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
+  def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT)
+
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
   def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
```
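
A usage sketch for the new conf; it assumes a `SparkSession` named `spark`, and the value is illustrative:

```scala
// The conf is internal but settable like any other. Raising the cap lets larger
// join conditions be converted to CNF, at the cost of larger pushed-down
// filters; the default is 128 conjuncts.
spark.conf.set("spark.sql.optimizer.maxCNFNodeCount", "256")
```
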
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala

Lines changed: 128 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType

class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHelper with PlanTest {
  private val a = AttributeReference("A", BooleanType)(exprId = ExprId(1)).withQualifier(Seq("ta"))
  private val b = AttributeReference("B", BooleanType)(exprId = ExprId(2)).withQualifier(Seq("tb"))
  private val c = AttributeReference("C", BooleanType)(exprId = ExprId(3)).withQualifier(Seq("tc"))
  private val d = AttributeReference("D", BooleanType)(exprId = ExprId(4)).withQualifier(Seq("td"))
  private val e = AttributeReference("E", BooleanType)(exprId = ExprId(5)).withQualifier(Seq("te"))
  private val f = AttributeReference("F", BooleanType)(exprId = ExprId(6)).withQualifier(Seq("tf"))
  private val g = AttributeReference("G", BooleanType)(exprId = ExprId(7)).withQualifier(Seq("tg"))
  private val h = AttributeReference("H", BooleanType)(exprId = ExprId(8)).withQualifier(Seq("th"))
  private val i = AttributeReference("I", BooleanType)(exprId = ExprId(9)).withQualifier(Seq("ti"))
  private val j = AttributeReference("J", BooleanType)(exprId = ExprId(10)).withQualifier(Seq("tj"))
  private val a1 =
    AttributeReference("a1", BooleanType)(exprId = ExprId(11)).withQualifier(Seq("ta"))
  private val a2 =
    AttributeReference("a2", BooleanType)(exprId = ExprId(12)).withQualifier(Seq("ta"))
  private val b1 =
    AttributeReference("b1", BooleanType)(exprId = ExprId(13)).withQualifier(Seq("tb"))

  // Check CNF conversion against an expected expression, assuming the input yields
  // a non-empty result.
  private def checkCondition(input: Expression, expected: Expression): Unit = {
    val cnf = conjunctiveNormalForm(input)
    assert(cnf.nonEmpty)
    val result = cnf.reduceLeft(And)
    assert(result.semanticEquals(expected))
  }

  test("Keep non-predicated expressions") {
    checkCondition(a, a)
    checkCondition(Literal(1), Literal(1))
  }

  test("Conversion of Not") {
    checkCondition(!a, !a)
    checkCondition(!(!a), a)
    checkCondition(!(!(a && b)), a && b)
    checkCondition(!(!(a || b)), a || b)
    checkCondition(!(a || b), !a && !b)
    checkCondition(!(a && b), !a || !b)
  }

  test("Conversion of And") {
    checkCondition(a && b, a && b)
    checkCondition(a && b && c, a && b && c)
    checkCondition(a && (b || c), a && (b || c))
    checkCondition((a || b) && c, (a || b) && c)
    checkCondition(a && b && c && d, a && b && c && d)
  }

  test("Conversion of Or") {
    checkCondition(a || b, a || b)
    checkCondition(a || b || c, a || b || c)
    checkCondition(a || b || c || d, a || b || c || d)
    checkCondition((a && b) || c, (a || c) && (b || c))
    checkCondition((a && b) || (c && d), (a || c) && (a || d) && (b || c) && (b || d))
  }

  test("More complex cases") {
    checkCondition(a && !(b || c), a && !b && !c)
    checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d))
    checkCondition(a || b || c && d, (a || b || c) && (a || b || d))
    checkCondition(a || (b && c || d), (a || b || d) && (a || c || d))
    checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e))
    checkCondition(((a && b) || c) || (d || e), (a || c || d || e) && (b || c || d || e))

    checkCondition(
      (a && b && c) || (d && e && f),
      (a || d) && (a || e) && (a || f) && (b || d) && (b || e) && (b || f) &&
        (c || d) && (c || e) && (c || f)
    )
  }

  test("Aggregate predicate of same qualifiers to avoid expanding") {
    checkCondition(((a && b && a1) || c), ((a && a1) || c) && (b || c))
    checkCondition(((a && a1 && b) || c), ((a && a1) || c) && (b || c))
    checkCondition(((b && d && a && a1) || c), ((a && a1) || c) && (b || c) && (d || c))
    checkCondition(((b && a2 && d && a && a1) || c), ((a2 && a && a1) || c) && (b || c) && (d || c))
    checkCondition(((b && d && a && a1 && b1) || c),
      ((a && a1) || c) && ((b && b1) || c) && (d || c))
    checkCondition((a && a1) || (b && b1), (a && a1) || (b && b1))
    checkCondition((a && a1 && c) || (b && b1), ((a && a1) || (b && b1)) && (c || (b && b1)))
  }

  test("Return Seq.empty when exceeding MAX_CNF_NODE_COUNT") {
    // The following expression contains 36 conjunctive sub-expressions in CNF
    val input = (a && b && c) || (d && e && f) || (g && h && i && j)
    // The following expression contains 9 conjunctive sub-expressions in CNF
    val input2 = (a && b && c) || (d && e && f)
    Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
      withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
        if (maxCount < 36) {
          assert(conjunctiveNormalForm(input).isEmpty)
        } else {
          assert(conjunctiveNormalForm(input).nonEmpty)
        }
        if (maxCount < 9) {
          assert(conjunctiveNormalForm(input2).isEmpty)
        } else {
          assert(conjunctiveNormalForm(input2).nonEmpty)
        }
      }
    }
  }
}
```
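
A note on the conjunct counts asserted in the last test: each `Or` of two CNF results with m and n conjuncts yields m × n clauses, and since every attribute in this suite carries a distinct qualifier, the grouping shortcut never merges anything. So `(a && b && c) || (d && e && f)` produces 3 × 3 = 9 clauses, and Or-ing that result with `(g && h && i && j)` produces 9 × 4 = 36, which is exactly where the probed `maxCount` thresholds of 9 and 36 flip the assertions between empty and non-empty.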
