Skip to content

Commit 58e6fa5

Browse files
committed
address comments
1 parent 7c102a5 commit 58e6fa5

File tree

3 files changed

+25
-34
lines changed

3 files changed

+25
-34
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,12 @@ trait PredicateHelper extends Logging {
204204
/**
205205
* Convert an expression into conjunctive normal form.
206206
* Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
207-
* CNF can explode exponentially in the size of the input expression when converting Or clauses.
208-
* Use a configuration MAX_CNF_NODE_COUNT to prevent such cases.
207+
* CNF can explode exponentially in the size of the input expression when converting [[Or]]
208+
* clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases.
209209
*
210-
* @param condition to be conversed into CNF.
211-
* @return If the number of expressions exceeds threshold on converting Or, return Seq.empty.
212-
* If the conversion repeatedly expands nondeterministic expressions, return Seq.empty.
213-
* Otherwise, return the converted result as sequence of disjunctive expressions.
210+
* @param condition to be converted into CNF.
211+
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions
212+
* exceeds threshold on converting `Or`, `Seq.empty` is returned.
214213
*/
215214
def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
216215
val postOrderNodes = postOrderTraversal(condition)
@@ -220,20 +219,23 @@ trait PredicateHelper extends Logging {
220219
while (postOrderNodes.nonEmpty) {
221220
val cnf = postOrderNodes.pop() match {
222221
case _: And =>
223-
val right: Seq[Expression] = resultStack.pop()
224-
val left: Seq[Expression] = resultStack.pop()
222+
val right = resultStack.pop()
223+
val left = resultStack.pop()
225224
left ++ right
226225
case _: Or =>
227226
// For each side, there is no need to expand predicates of the same references.
228-
// So here we can aggregate predicates of the same references as one single predicate,
227+
// So here we can aggregate predicates of the same qualifier as one single predicate,
229228
// for reducing the size of pushed down predicates and corresponding codegen.
230229
val right = groupExpressionsByQualifier(resultStack.pop())
231230
val left = groupExpressionsByQualifier(resultStack.pop())
232231
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
233232
if (left.size * right.size > maxCnfNodeCount) {
233+
logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
234+
"The CNF conversion is skipped and returning Seq.empty now. To avoid this, you can " +
235+
s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.")
234236
return Seq.empty
235237
} else {
236-
for {x <- left; y <- right} yield Or(x, y)
238+
for { x <- left; y <- right } yield Or(x, y)
237239
}
238240
case other => other :: Nil
239241
}
@@ -247,8 +249,7 @@ trait PredicateHelper extends Logging {
247249
resultStack.top
248250
}
249251

250-
private def groupExpressionsByQualifier(
251-
expressions: Seq[Expression]): Seq[Expression] = {
252+
private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
252253
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
253254
}
254255

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
3030
*/
3131
object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
3232
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
33-
case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
33+
case j @ Join(left, right, joinType, Some(joinCondition), hint) if joinType != FullOuter =>
3434
val predicates = conjunctiveNormalForm(joinCondition)
3535
if (predicates.isEmpty) {
3636
j
@@ -53,7 +53,6 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
5353
Join(newLeft, right, RightOuter, Some(joinCondition), hint)
5454
case LeftOuter | LeftAnti | ExistenceJoin(_) =>
5555
Join(left, newRight, joinType, Some(joinCondition), hint)
56-
case FullOuter => j
5756
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
5857
case UsingJoin(_, _) => sys.error("Untransformed Using join node")
5958
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class FilterPushdownSuite extends PlanTest {
5858

5959
val testRelation1 = LocalRelation(attrD)
6060

61-
val simpleDisjuncitvePredicate =
61+
val simpleDisjunctivePredicate =
6262
("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)
6363
val expectedCNFPredicatePushDownResult = {
6464
val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x)
@@ -1251,10 +1251,7 @@ class FilterPushdownSuite extends PlanTest {
12511251
val x = testRelation.subquery('x)
12521252
val y = testRelation.subquery('y)
12531253

1254-
val originalQuery = {
1255-
x.join(y)
1256-
.where(("x.b".attr === "y.b".attr) && (simpleDisjuncitvePredicate))
1257-
}
1254+
val originalQuery = x.join(y).where(("x.b".attr === "y.b".attr) && (simpleDisjunctivePredicate))
12581255

12591256
val optimized = Optimize.execute(originalQuery.analyze)
12601257
comparePlans(optimized, expectedCNFPredicatePushDownResult)
@@ -1264,9 +1261,8 @@ class FilterPushdownSuite extends PlanTest {
12641261
val x = testRelation.subquery('x)
12651262
val y = testRelation.subquery('y)
12661263

1267-
val originalQuery = {
1268-
x.join(y, condition = Some(("x.b".attr === "y.b".attr) && (simpleDisjuncitvePredicate)))
1269-
}
1264+
val originalQuery =
1265+
x.join(y, condition = Some(("x.b".attr === "y.b".attr) && (simpleDisjunctivePredicate)))
12701266

12711267
val optimized = Optimize.execute(originalQuery.analyze)
12721268
comparePlans(optimized, expectedCNFPredicatePushDownResult)
@@ -1296,11 +1292,10 @@ class FilterPushdownSuite extends PlanTest {
12961292
val x = testRelation.subquery('x)
12971293
val y = testRelation.subquery('y)
12981294

1299-
val originalQuery = {
1295+
val originalQuery =
13001296
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
13011297
&& Not(("x.a".attr > 3)
13021298
&& ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11))))
1303-
}
13041299

13051300
val optimized = Optimize.execute(originalQuery.analyze)
13061301
val left = testRelation.where('a <= 3 || 'a >= 2).subquery('x)
@@ -1317,10 +1312,9 @@ class FilterPushdownSuite extends PlanTest {
13171312
val x = testRelation.subquery('x)
13181313
val y = testRelation.subquery('y)
13191314

1320-
val originalQuery = {
1315+
val originalQuery =
13211316
x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr)
1322-
&& simpleDisjuncitvePredicate))
1323-
}
1317+
&& simpleDisjunctivePredicate))
13241318

13251319
val optimized = Optimize.execute(originalQuery.analyze)
13261320
val left = testRelation.subquery('x)
@@ -1337,10 +1331,9 @@ class FilterPushdownSuite extends PlanTest {
13371331
val x = testRelation.subquery('x)
13381332
val y = testRelation.subquery('y)
13391333

1340-
val originalQuery = {
1334+
val originalQuery =
13411335
x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr)
1342-
&& simpleDisjuncitvePredicate))
1343-
}
1336+
&& simpleDisjunctivePredicate))
13441337

13451338
val optimized = Optimize.execute(originalQuery.analyze)
13461339
val left = testRelation.where('a > 3 || 'a > 1).subquery('x)
@@ -1357,10 +1350,9 @@ class FilterPushdownSuite extends PlanTest {
13571350
val x = testRelation.subquery('x)
13581351
val y = testRelation.subquery('y)
13591352

1360-
val originalQuery = {
1353+
val originalQuery =
13611354
x.join(y, condition = Some(("x.b".attr === "y.b".attr) && ((("x.a".attr > 3) &&
13621355
("x.a".attr < 13) && ("y.c".attr <= 5)) || (("y.a".attr > 2) && ("y.c".attr < 1)))))
1363-
}
13641356

13651357
val optimized = Optimize.execute(originalQuery.analyze)
13661358
val left = testRelation.subquery('x)
@@ -1376,11 +1368,10 @@ class FilterPushdownSuite extends PlanTest {
13761368
val x = testRelation.subquery('x)
13771369
val y = testRelation.subquery('y)
13781370

1379-
val originalQuery = {
1371+
val originalQuery =
13801372
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
13811373
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
13821374
|| (("y.a".attr > 2) && ("y.c".attr < 1)))))
1383-
}
13841375

13851376
Seq(0, 10).foreach { count =>
13861377
withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> count.toString) {

0 commit comments

Comments
 (0)