Skip to content

Commit effa0c8

Browse files
wangyum authored and gengliangwang committed
From Yuming: [SPARK-31705][SQL] Push predicate through join by rewriting join condition to conjunctive normal form
1 parent 53e8151 commit effa0c8

File tree

2 files changed

+165
-3
lines changed

2 files changed

+165
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
5151
override protected val blacklistedOnceBatches: Set[String] =
5252
Set(
5353
"PartitionPruning",
54-
"Extract Python UDFs")
54+
"Extract Python UDFs",
55+
"Push CNF predicate through join")
5556

5657
protected def fixedPoint =
5758
FixedPoint(
@@ -118,7 +119,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
118119
Batch("Infer Filters", Once,
119120
InferFiltersFromConstraints) ::
120121
Batch("Operator Optimization after Inferring Filters", fixedPoint,
121-
rulesWithoutInferFiltersFromConstraints: _*) :: Nil
122+
rulesWithoutInferFiltersFromConstraints: _*) ::
123+
// Set strategy to Once to avoid pushing filter every time because we do not change the
124+
// join condition.
125+
Batch("Push CNF predicate through join", Once,
126+
PushCNFPredicateThroughJoin) :: Nil
122127
}
123128

124129
val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 158 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,12 +25,17 @@ import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.plans._
2626
import org.apache.spark.sql.catalyst.plans.logical._
2727
import org.apache.spark.sql.catalyst.rules._
28+
import org.apache.spark.sql.internal.SQLConf
2829
import org.apache.spark.sql.types.{BooleanType, IntegerType}
2930
import org.apache.spark.unsafe.types.CalendarInterval
3031

3132
class FilterPushdownSuite extends PlanTest {
3233

3334
object Optimize extends RuleExecutor[LogicalPlan] {
35+
36+
override protected val blacklistedOnceBatches: Set[String] =
37+
Set("Push predicate through join by CNF")
38+
3439
val batches =
3540
Batch("Subqueries", Once,
3641
EliminateSubqueryAliases) ::
@@ -39,7 +44,9 @@ class FilterPushdownSuite extends PlanTest {
3944
PushPredicateThroughNonJoin,
4045
BooleanSimplification,
4146
PushPredicateThroughJoin,
42-
CollapseProject) :: Nil
47+
CollapseProject) ::
48+
Batch("Push predicate through join by CNF", Once,
49+
PushCNFPredicateThroughJoin) :: Nil
4350
}
4451

4552
val attrA = 'a.int
@@ -1230,4 +1237,154 @@ class FilterPushdownSuite extends PlanTest {
12301237

12311238
comparePlans(Optimize.execute(query.analyze), expected)
12321239
}
1240+
1241+
test("inner join: rewrite filter predicates to conjunctive normal form") {
1242+
val x = testRelation.subquery('x)
1243+
val y = testRelation.subquery('y)
1244+
1245+
val originalQuery = {
1246+
x.join(y)
1247+
.where(("x.b".attr === "y.b".attr)
1248+
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))
1249+
}
1250+
1251+
val optimized = Optimize.execute(originalQuery.analyze)
1252+
val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x)
1253+
val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
1254+
val correctAnswer =
1255+
left.join(right, condition = Some("x.b".attr === "y.b".attr
1256+
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
1257+
.analyze
1258+
1259+
comparePlans(optimized, correctAnswer)
1260+
}
1261+
1262+
test("inner join: rewrite join predicates to conjunctive normal form") {
1263+
val x = testRelation.subquery('x)
1264+
val y = testRelation.subquery('y)
1265+
1266+
val originalQuery = {
1267+
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
1268+
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
1269+
}
1270+
1271+
val optimized = Optimize.execute(originalQuery.analyze)
1272+
val left = testRelation.where('a > 3 || 'a > 1).subquery('x)
1273+
val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
1274+
val correctAnswer =
1275+
left.join(right, condition = Some("x.b".attr === "y.b".attr
1276+
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
1277+
.analyze
1278+
1279+
comparePlans(optimized, correctAnswer)
1280+
}
1281+
1282+
test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") {
1283+
val x = testRelation.subquery('x)
1284+
val y = testRelation.subquery('y)
1285+
1286+
val originalQuery = {
1287+
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
1288+
&& Not(("x.a".attr > 3)
1289+
&& ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11))))
1290+
}
1291+
1292+
val optimized = Optimize.execute(originalQuery.analyze)
1293+
val left = testRelation.where('a <= 3 || 'a >= 2).subquery('x)
1294+
val right = testRelation.subquery('y)
1295+
val correctAnswer =
1296+
left.join(right, condition = Some("x.b".attr === "y.b".attr
1297+
&& (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13)))
1298+
&& (("x.a".attr <= 1) || ("y.a".attr <= 11))))
1299+
.analyze
1300+
comparePlans(optimized, correctAnswer)
1301+
}
1302+
1303+
test("left join: rewrite join predicates to conjunctive normal form") {
1304+
val x = testRelation.subquery('x)
1305+
val y = testRelation.subquery('y)
1306+
1307+
val originalQuery = {
1308+
x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr)
1309+
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
1310+
}
1311+
1312+
val optimized = Optimize.execute(originalQuery.analyze)
1313+
val left = testRelation.subquery('x)
1314+
val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
1315+
val correctAnswer =
1316+
left.join(right, joinType = LeftOuter, condition = Some("x.b".attr === "y.b".attr
1317+
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
1318+
.analyze
1319+
1320+
comparePlans(optimized, correctAnswer)
1321+
}
1322+
1323+
test("right join: rewrite join predicates to conjunctive normal form") {
1324+
val x = testRelation.subquery('x)
1325+
val y = testRelation.subquery('y)
1326+
1327+
val originalQuery = {
1328+
x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr)
1329+
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
1330+
}
1331+
1332+
val optimized = Optimize.execute(originalQuery.analyze)
1333+
val left = testRelation.where('a > 3 || 'a > 1).subquery('x)
1334+
val right = testRelation.subquery('y)
1335+
val correctAnswer =
1336+
left.join(right, joinType = RightOuter, condition = Some("x.b".attr === "y.b".attr
1337+
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
1338+
.analyze
1339+
1340+
comparePlans(optimized, correctAnswer)
1341+
}
1342+
1343+
test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") {
1344+
val x = testRelation.subquery('x)
1345+
val y = testRelation.subquery('y)
1346+
1347+
val originalQuery = {
1348+
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
1349+
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
1350+
|| (("y.a".attr > 2) && ("y.c".attr < 1)))))
1351+
}
1352+
1353+
val optimized = Optimize.execute(originalQuery.analyze)
1354+
val left = testRelation.subquery('x)
1355+
val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)
1356+
val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
1357+
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
1358+
|| (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze
1359+
1360+
comparePlans(optimized, correctAnswer)
1361+
}
1362+
1363+
test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_CNF_NODE_COUNT.key}=0") {
1364+
val x = testRelation.subquery('x)
1365+
val y = testRelation.subquery('y)
1366+
1367+
val originalQuery = {
1368+
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
1369+
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
1370+
|| (("y.a".attr > 2) && ("y.c".attr < 1)))))
1371+
}
1372+
1373+
Seq(0, 10).foreach { depth =>
1374+
withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> depth.toString) {
1375+
val optimized = Optimize.execute(originalQuery.analyze)
1376+
val (left, right) = if (depth == 0) {
1377+
(testRelation.subquery('x), testRelation.subquery('y))
1378+
} else {
1379+
(testRelation.subquery('x),
1380+
testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y))
1381+
}
1382+
val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
1383+
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
1384+
|| (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze
1385+
1386+
comparePlans(optimized, correctAnswer)
1387+
}
1388+
}
1389+
}
12331390
}

0 commit comments

Comments
 (0)