Skip to content

Commit 97c3414

Browse files
committed
address comments and add test cases
1 parent 76e3825 commit 97c3414

File tree

3 files changed

+212
-84
lines changed

3 files changed

+212
-84
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import scala.collection.immutable.TreeSet
21+
import scala.collection.mutable
2122

2223
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
2324
import org.apache.spark.sql.catalyst.InternalRow
@@ -198,6 +199,88 @@ trait PredicateHelper {
198199
case e: Unevaluable => false
199200
case e => e.children.forall(canEvaluateWithinJoin)
200201
}
202+
203+
/**
204+
* Convert an expression into conjunctive normal form.
205+
* Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
206+
* CNF can explode exponentially in the size of the input expression when converting Or clauses.
207+
* Use a configuration MAX_CNF_NODE_COUNT to prevent such cases.
208+
*
209+
* @param condition to be conversed into CNF.
210+
* @return If the number of expressions exceeds threshold on converting Or, return Seq.empty.
211+
* If the conversion repeatedly expands nondeterministic expressions, return Seq.empty.
212+
* Otherwise, return the converted result as sequence of disjunctive expressions.
213+
*/
214+
def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
215+
val postOrderNodes = postOrderTraversal(condition)
216+
val resultStack = new mutable.Stack[Seq[Expression]]
217+
val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
218+
// Bottom up approach to get CNF of sub-expressions
219+
while (postOrderNodes.nonEmpty) {
220+
val cnf = postOrderNodes.pop() match {
221+
case _: And =>
222+
val right: Seq[Expression] = resultStack.pop()
223+
val left: Seq[Expression] = resultStack.pop()
224+
left ++ right
225+
case _: Or =>
226+
// For each side, there is no need to expand predicates of the same references.
227+
// So here we can aggregate predicates of the same references as one single predicate,
228+
// for reducing the size of pushed down predicates and corresponding codegen.
229+
val right = aggregateExpressionsOfSameQualifiers(resultStack.pop())
230+
val left = aggregateExpressionsOfSameQualifiers(resultStack.pop())
231+
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
232+
if (left.size * right.size > maxCnfNodeCount) {
233+
Seq.empty
234+
} else {
235+
for {x <- left; y <- right} yield Or(x, y)
236+
}
237+
case other => other :: Nil
238+
}
239+
if (cnf.isEmpty) {
240+
return Seq.empty
241+
}
242+
resultStack.push(cnf)
243+
}
244+
assert(resultStack.length == 1,
245+
s"Fail to convert expression ${condition} to conjunctive normal form")
246+
resultStack.top
247+
}
248+
249+
private def aggregateExpressionsOfSameQualifiers(
250+
expressions: Seq[Expression]): Seq[Expression] = {
251+
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
252+
}
253+
254+
/**
255+
* Iterative post order traversal over a binary tree built by And/Or clauses.
256+
* @param condition to be traversed as binary tree
257+
* @return sub-expressions in post order traversal as an Array.
258+
* The first element of result Array is the leftmost node.
259+
*/
260+
private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
261+
val stack = new mutable.Stack[Expression]
262+
val result = new mutable.Stack[Expression]
263+
stack.push(condition)
264+
while (stack.nonEmpty) {
265+
val node = stack.pop()
266+
node match {
267+
case Not(a And b) => stack.push(Or(Not(a), Not(b)))
268+
case Not(a Or b) => stack.push(And(Not(a), Not(b)))
269+
case Not(Not(a)) => stack.push(a)
270+
case a And b =>
271+
result.push(node)
272+
stack.push(a)
273+
stack.push(b)
274+
case a Or b =>
275+
result.push(node)
276+
stack.push(a)
277+
stack.push(b)
278+
case _ =>
279+
result.push(node)
280+
}
281+
}
282+
result
283+
}
201284
}
202285

203286
@ExpressionDescription(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala

Lines changed: 1 addition & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20-
import scala.collection.mutable
21-
22-
import org.apache.spark.sql.catalyst.expressions.{And, Expression, Not, Or, PredicateHelper}
20+
import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
2321
import org.apache.spark.sql.catalyst.plans._
2422
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
2523
import org.apache.spark.sql.catalyst.rules.Rule
@@ -32,87 +30,6 @@ import org.apache.spark.sql.internal.SQLConf
3230
* when predicate pushdown happens.
3331
*/
3432
object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
35-
/**
36-
* Convert an expression into conjunctive normal form.
37-
* Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
38-
* CNF can explode exponentially in the size of the input expression when converting Or clauses.
39-
* Use a configuration MAX_CNF_NODE_COUNT to prevent such cases.
40-
*
41-
* @param condition to be conversed into CNF.
42-
* @return If the number of expressions exceeds threshold on converting Or, return Seq.empty.
43-
* If the conversion repeatedly expands nondeterministic expressions, return Seq.empty.
44-
* Otherwise, return the converted result as sequence of disjunctive expressions.
45-
*/
46-
protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
47-
val postOrderNodes = postOrderTraversal(condition)
48-
val resultStack = new scala.collection.mutable.Stack[Seq[Expression]]
49-
val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
50-
// Bottom up approach to get CNF of sub-expressions
51-
while (postOrderNodes.nonEmpty) {
52-
val cnf = postOrderNodes.pop() match {
53-
case _: And =>
54-
val right: Seq[Expression] = resultStack.pop()
55-
val left: Seq[Expression] = resultStack.pop()
56-
left ++ right
57-
case _: Or =>
58-
// For each side, there is no need to expand predicates of the same references.
59-
// So here we can aggregate predicates of the same references as one single predicate,
60-
// for reducing the size of pushed down predicates and corresponding codegen.
61-
val right = aggregateExpressionsOfSameReference(resultStack.pop())
62-
val left = aggregateExpressionsOfSameReference(resultStack.pop())
63-
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
64-
if (left.size * right.size > maxCnfNodeCount) {
65-
Seq.empty
66-
} else {
67-
for {x <- left; y <- right} yield Or(x, y)
68-
}
69-
case other => other :: Nil
70-
}
71-
if (cnf.isEmpty) {
72-
return Seq.empty
73-
}
74-
resultStack.push(cnf)
75-
}
76-
assert(resultStack.length == 1,
77-
s"Fail to convert expression ${condition} to conjunctive normal form")
78-
resultStack.top
79-
}
80-
81-
private def aggregateExpressionsOfSameReference(expressions: Seq[Expression]): Seq[Expression] = {
82-
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
83-
}
84-
/**
85-
* Iterative post order traversal over a binary tree built by And/Or clauses.
86-
* @param condition to be traversed as binary tree
87-
* @return sub-expressions in post order traversal as an Array.
88-
* The first element of result Array is the leftmost node.
89-
*/
90-
private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
91-
val stack = new mutable.Stack[Expression]
92-
val result = new mutable.Stack[Expression]
93-
stack.push(condition)
94-
while (stack.nonEmpty) {
95-
val node = stack.pop()
96-
node match {
97-
case Not(a And b) => stack.push(Or(Not(a), Not(b)))
98-
case Not(a Or b) => stack.push(And(Not(a), Not(b)))
99-
case Not(Not(a)) => stack.push(a)
100-
case a And b =>
101-
result.push(node)
102-
stack.push(a)
103-
stack.push(b)
104-
case a Or b =>
105-
result.push(node)
106-
stack.push(a)
107-
stack.push(b)
108-
case _ =>
109-
result.push(node)
110-
}
111-
}
112-
result
113-
}
114-
115-
11633
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
11734
case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
11835
val predicates = conjunctiveNormalForm(joinCondition)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.plans.PlanTest
23+
import org.apache.spark.sql.internal.SQLConf
24+
import org.apache.spark.sql.types.BooleanType
25+
26+
class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHelper with PlanTest {
27+
private val a = AttributeReference("A", BooleanType)(exprId = ExprId(1)).withQualifier(Seq("ta"))
28+
private val b = AttributeReference("B", BooleanType)(exprId = ExprId(2)).withQualifier(Seq("tb"))
29+
private val c = AttributeReference("C", BooleanType)(exprId = ExprId(3)).withQualifier(Seq("tc"))
30+
private val d = AttributeReference("D", BooleanType)(exprId = ExprId(4)).withQualifier(Seq("td"))
31+
private val e = AttributeReference("E", BooleanType)(exprId = ExprId(5)).withQualifier(Seq("te"))
32+
private val f = AttributeReference("F", BooleanType)(exprId = ExprId(6)).withQualifier(Seq("tf"))
33+
private val g = AttributeReference("C", BooleanType)(exprId = ExprId(7)).withQualifier(Seq("tg"))
34+
private val h = AttributeReference("D", BooleanType)(exprId = ExprId(8)).withQualifier(Seq("th"))
35+
private val i = AttributeReference("E", BooleanType)(exprId = ExprId(9)).withQualifier(Seq("ti"))
36+
private val j = AttributeReference("F", BooleanType)(exprId = ExprId(10)).withQualifier(Seq("tj"))
37+
private val a1 =
38+
AttributeReference("a1", BooleanType)(exprId = ExprId(11)).withQualifier(Seq("ta"))
39+
private val a2 =
40+
AttributeReference("a2", BooleanType)(exprId = ExprId(12)).withQualifier(Seq("ta"))
41+
private val b1 =
42+
AttributeReference("b1", BooleanType)(exprId = ExprId(12)).withQualifier(Seq("tb"))
43+
44+
// Check CNF conversion with expected expression, assuming the input has non-empty result.
45+
private def checkCondition(input: Expression, expected: Expression): Unit = {
46+
val cnf = conjunctiveNormalForm(input)
47+
assert(cnf.nonEmpty)
48+
val result = cnf.reduceLeft(And)
49+
assert(result.semanticEquals(expected))
50+
}
51+
52+
test("Keep non-predicated expressions") {
53+
checkCondition(a, a)
54+
checkCondition(Literal(1), Literal(1))
55+
}
56+
57+
test("Conversion of Not") {
58+
checkCondition(!a, !a)
59+
checkCondition(!(!a), a)
60+
checkCondition(!(!(a && b)), a && b)
61+
checkCondition(!(!(a || b)), a || b)
62+
checkCondition(!(a || b), !a && !b)
63+
checkCondition(!(a && b), !a || !b)
64+
}
65+
66+
test("Conversion of And") {
67+
checkCondition(a && b, a && b)
68+
checkCondition(a && b && c, a && b && c)
69+
checkCondition(a && (b || c), a && (b || c))
70+
checkCondition((a || b) && c, (a || b) && c)
71+
checkCondition(a && b && c && d, a && b && c && d)
72+
}
73+
74+
test("Conversion of Or") {
75+
checkCondition(a || b, a || b)
76+
checkCondition(a || b || c, a || b || c)
77+
checkCondition(a || b || c || d, a || b || c || d)
78+
checkCondition((a && b) || c, (a || c) && (b || c))
79+
checkCondition((a && b) || (c && d), (a || c) && (a || d) && (b || c) && (b || d))
80+
}
81+
82+
test("More complex cases") {
83+
checkCondition(a && !(b || c), a && !b && !c)
84+
checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d))
85+
checkCondition(a || b || c && d, (a || b || c) && (a || b || d))
86+
checkCondition(a || (b && c || d), (a || b || d) && (a || c || d))
87+
checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e))
88+
checkCondition(((a && b) || c) || (d || e), (a || c || d || e) && (b || c || d || e))
89+
90+
checkCondition(
91+
(a && b && c) || (d && e && f),
92+
(a || d) && (a || e) && (a || f) && (b || d) && (b || e) && (b || f) &&
93+
(c || d) && (c || e) && (c || f)
94+
)
95+
}
96+
97+
test("Aggregate predicate of same qualifiers to avoid expanding") {
98+
checkCondition(((a && b && a1) || c), ((a && a1) || c) && (b ||c))
99+
checkCondition(((a && a1 && b) || c), ((a && a1) || c) && (b ||c))
100+
checkCondition(((b && d && a && a1) || c), ((a && a1) || c) && (b ||c) && (d || c))
101+
checkCondition(((b && a2 && d && a && a1) || c), ((a2 && a && a1) || c) && (b ||c) && (d || c))
102+
checkCondition(((b && d && a && a1 && b1) || c),
103+
((a && a1) || c) && ((b && b1) ||c) && (d || c))
104+
checkCondition((a && a1) || (b && b1), (a && a1) || (b && b1))
105+
checkCondition((a && a1 && c) || (b && b1), ((a && a1) || (b && b1)) && (c || (b && b1)))
106+
}
107+
108+
test("Return None when exceeding MAX_CNF_NODE_COUNT") {
109+
// The following expression contains 36 conjunctive sub-expressions in CNF
110+
val input = (a && b && c) || (d && e && f) || (g && h && i && j)
111+
// The following expression contains 9 conjunctive sub-expressions in CNF
112+
val input2 = (a && b && c) || (d && e && f)
113+
Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
114+
withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
115+
if (maxCount < 36) {
116+
assert(conjunctiveNormalForm(input).isEmpty)
117+
} else {
118+
assert(conjunctiveNormalForm(input).nonEmpty)
119+
}
120+
if (maxCount < 9) {
121+
assert(conjunctiveNormalForm(input2).isEmpty)
122+
} else {
123+
assert(conjunctiveNormalForm(input2).nonEmpty)
124+
}
125+
}
126+
}
127+
}
128+
}

0 commit comments

Comments
 (0)