Skip to content

Commit 76e3825

Browse files
committed
fix test case; reduce threshold default value
1 parent 8952a6a commit 76e3825

File tree

3 files changed

+30
-12
lines changed

3 files changed

+30
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,11 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
5555
val left: Seq[Expression] = resultStack.pop()
5656
left ++ right
5757
case _: Or =>
58-
val right: Seq[Expression] = resultStack.pop()
59-
val left: Seq[Expression] = resultStack.pop()
58+
// For each side, there is no need to expand predicates of the same references.
59+
// So here we can aggregate predicates of the same references as one single predicate,
60+
// for reducing the size of pushed down predicates and corresponding codegen.
61+
val right = aggregateExpressionsOfSameReference(resultStack.pop())
62+
val left = aggregateExpressionsOfSameReference(resultStack.pop())
6063
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
6164
if (left.size * right.size > maxCnfNodeCount) {
6265
Seq.empty
@@ -75,6 +78,9 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
7578
resultStack.top
7679
}
7780

81+
private def aggregateExpressionsOfSameReference(expressions: Seq[Expression]): Seq[Expression] = {
82+
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
83+
}
7884
/**
7985
* Iterative post order traversal over a binary tree built by And/Or clauses.
8086
* @param condition to be traversed as binary tree

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ object SQLConf {
556556
.intConf
557557
.checkValue(_ >= 0,
558558
"The depth of the maximum rewriting conjunction normal form must be positive.")
559-
.createWithDefault(256)
559+
.createWithDefault(128)
560560

561561
val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
562562
.internal()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,14 +1293,7 @@ class FilterPushdownSuite extends PlanTest {
12931293
val left = testRelation.where(
12941294
('a === 5 || 'a === 2 || 'a === 1)).subquery('x)
12951295
val right = testRelation.where(
1296-
('a >= 2 || 'a >= 1 || 'a >= 9) &&
1297-
('a >= 2 || 'a >= 1 || 'a <= 27) &&
1298-
('a >= 2 || 'a <=14 || 'a >= 9) &&
1299-
('a >= 2 || 'a <=14 || 'a <= 27) &&
1300-
('a <= 3 || 'a >= 1 || 'a >= 9) &&
1301-
('a <= 3 || 'a >= 1 || 'a <= 27) &&
1302-
('a <= 3 || 'a <=14 || 'a >= 9) &&
1303-
('a <= 3 || 'a <=14 || 'a <= 27)).subquery('y)
1296+
('a >= 2 && 'a <= 3) || ('a >= 1 && 'a <= 14) || ('a >= 9 && 'a <= 27)).subquery('y)
13041297
val correctAnswer = left.join(right, condition = Some(joinCondition)).analyze
13051298

13061299
comparePlans(optimized, correctAnswer)
@@ -1367,6 +1360,25 @@ class FilterPushdownSuite extends PlanTest {
13671360
comparePlans(optimized, correctAnswer)
13681361
}
13691362

1363+
test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") {
1364+
val x = testRelation.subquery('x)
1365+
val y = testRelation.subquery('y)
1366+
1367+
val originalQuery = {
1368+
x.join(y, condition = Some(("x.b".attr === "y.b".attr) && ((("x.a".attr > 3) &&
1369+
("x.a".attr < 13) && ("y.c".attr <= 5)) || (("y.a".attr > 2) && ("y.c".attr < 1)))))
1370+
}
1371+
1372+
val optimized = Optimize.execute(originalQuery.analyze)
1373+
val left = testRelation.subquery('x)
1374+
val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)
1375+
val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr &&
1376+
((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) ||
1377+
(("y.a".attr > 2) && ("y.c".attr < 1))))).analyze
1378+
1379+
comparePlans(optimized, correctAnswer)
1380+
}
1381+
13701382
test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_CNF_NODE_COUNT.key}=0") {
13711383
val x = testRelation.subquery('x)
13721384
val y = testRelation.subquery('y)
@@ -1384,7 +1396,7 @@ class FilterPushdownSuite extends PlanTest {
13841396
(testRelation.subquery('x), testRelation.subquery('y))
13851397
} else {
13861398
(testRelation.subquery('x),
1387-
testRelation.where(('c <= 5 || 'c < 1) && ('c <=5 || 'a > 2)).subquery('y))
1399+
testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y))
13881400
}
13891401
val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
13901402
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))

0 commit comments

Comments
 (0)