@@ -25,12 +25,17 @@ import org.apache.spark.sql.catalyst.expressions._
25
25
import org .apache .spark .sql .catalyst .plans ._
26
26
import org .apache .spark .sql .catalyst .plans .logical ._
27
27
import org .apache .spark .sql .catalyst .rules ._
28
+ import org .apache .spark .sql .internal .SQLConf
28
29
import org .apache .spark .sql .types .{BooleanType , IntegerType }
29
30
import org .apache .spark .unsafe .types .CalendarInterval
30
31
31
32
class FilterPushdownSuite extends PlanTest {
32
33
33
34
object Optimize extends RuleExecutor [LogicalPlan ] {
35
+
36
+ override protected val blacklistedOnceBatches : Set [String ] =
37
+ Set (" Push predicate through join by CNF" )
38
+
34
39
val batches =
35
40
Batch (" Subqueries" , Once ,
36
41
EliminateSubqueryAliases ) ::
@@ -39,7 +44,9 @@ class FilterPushdownSuite extends PlanTest {
39
44
PushPredicateThroughNonJoin ,
40
45
BooleanSimplification ,
41
46
PushPredicateThroughJoin ,
42
- CollapseProject ) :: Nil
47
+ CollapseProject ) ::
48
+ Batch (" Push predicate through join by CNF" , Once ,
49
+ PushCNFPredicateThroughJoin ) :: Nil
43
50
}
44
51
45
52
val attrA = ' a .int
@@ -1230,4 +1237,154 @@ class FilterPushdownSuite extends PlanTest {
1230
1237
1231
1238
comparePlans(Optimize .execute(query.analyze), expected)
1232
1239
}
1240
+
1241
test("inner join: rewrite filter predicates to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // The same predicate is the Filter above the join in the input plan and the join
  // condition in the expected plan, so it is assembled once.
  val keysMatch = "x.b".attr === "y.b".attr
  val eitherRange =
    (("x.a".attr > 3) && ("y.a".attr > 13)) || (("x.a".attr > 1) && ("y.a".attr > 11))
  val predicate = keysMatch && eitherRange

  val optimized = Optimize.execute(x.join(y).where(predicate).analyze)

  // Expected: each join input is filtered by the disjunction of the comparisons on its own
  // column `a`, and the full predicate is kept as the join condition.
  val filteredLeft = testRelation.where('a > 3 || 'a > 1).subquery('x)
  val filteredRight = testRelation.where('a > 13 || 'a > 11).subquery('y)
  val expected = filteredLeft.join(filteredRight, condition = Some(predicate)).analyze

  comparePlans(optimized, expected)
}
1261
+
1262
test("inner join: rewrite join predicates to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Join condition shared by the input plan and the expected plan.
  val keysMatch = "x.b".attr === "y.b".attr
  val eitherRange =
    (("x.a".attr > 3) && ("y.a".attr > 13)) || (("x.a".attr > 1) && ("y.a".attr > 11))
  val joinCondition = keysMatch && eitherRange

  val optimized = Optimize.execute(x.join(y, condition = Some(joinCondition)).analyze)

  // Expected: both inputs gain a filter on their own column `a`; the join condition itself
  // is left as it was.
  val filteredLeft = testRelation.where('a > 3 || 'a > 1).subquery('x)
  val filteredRight = testRelation.where('a > 13 || 'a > 11).subquery('y)
  val expected = filteredLeft.join(filteredRight, condition = Some(joinCondition)).analyze

  comparePlans(optimized, expected)
}
1281
+
1282
test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  val keysMatch = "x.b".attr === "y.b".attr

  // Input condition: the key equality AND a negated disjunction. The parentheses spell out
  // the grouping that `&&` binding tighter than `||` already implies.
  val negated = Not(
    (("x.a".attr > 3) && ("x.a".attr < 2 || ("y.a".attr > 13))) ||
      (("x.a".attr > 1) && ("y.a".attr > 11)))
  val optimized = Optimize.execute(x.join(y, condition = Some(keysMatch && negated)).analyze)

  // Expected: the negation is pushed inwards in the join condition, and only the left input
  // receives a filter.
  val filteredLeft = testRelation.where('a <= 3 || 'a >= 2).subquery('x)
  val plainRight = testRelation.subquery('y)
  val firstClause = ("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13))
  val secondClause = ("x.a".attr <= 1) || ("y.a".attr <= 11)
  val expected = filteredLeft
    .join(plainRight, condition = Some(keysMatch && firstClause && secondClause))
    .analyze

  comparePlans(optimized, expected)
}
1302
+
1303
test("left join: rewrite join predicates to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Join condition shared by the input plan and the expected plan.
  val keysMatch = "x.b".attr === "y.b".attr
  val eitherRange =
    (("x.a".attr > 3) && ("y.a".attr > 13)) || (("x.a".attr > 1) && ("y.a".attr > 11))
  val joinCondition = keysMatch && eitherRange

  val optimized = Optimize.execute(
    x.join(y, joinType = LeftOuter, condition = Some(joinCondition)).analyze)

  // Expected: for a left outer join only the right input is filtered.
  val plainLeft = testRelation.subquery('x)
  val filteredRight = testRelation.where('a > 13 || 'a > 11).subquery('y)
  val expected = plainLeft
    .join(filteredRight, joinType = LeftOuter, condition = Some(joinCondition))
    .analyze

  comparePlans(optimized, expected)
}
1322
+
1323
test("right join: rewrite join predicates to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Join condition shared by the input plan and the expected plan.
  val keysMatch = "x.b".attr === "y.b".attr
  val eitherRange =
    (("x.a".attr > 3) && ("y.a".attr > 13)) || (("x.a".attr > 1) && ("y.a".attr > 11))
  val joinCondition = keysMatch && eitherRange

  val optimized = Optimize.execute(
    x.join(y, joinType = RightOuter, condition = Some(joinCondition)).analyze)

  // Expected: for a right outer join only the left input is filtered.
  val filteredLeft = testRelation.where('a > 3 || 'a > 1).subquery('x)
  val plainRight = testRelation.subquery('y)
  val expected = filteredLeft
    .join(plainRight, joinType = RightOuter, condition = Some(joinCondition))
    .analyze

  comparePlans(optimized, expected)
}
1342
+
1343
test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Join condition shared by the input plan and the expected plan: the key equality AND a
  // disjunction whose first branch mixes both sides and whose second branch is right-only.
  val keysMatch = "x.b".attr === "y.b".attr
  val mixedBranch = ("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)
  val rightOnlyBranch = ("y.a".attr > 2) && ("y.c".attr < 1)
  val joinCondition = keysMatch && (mixedBranch || rightOnlyBranch)

  val optimized = Optimize.execute(x.join(y, condition = Some(joinCondition)).analyze)

  // Expected: a single filter on the right input and none on the left.
  val plainLeft = testRelation.subquery('x)
  val filteredRight = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)
  val expected = plainLeft.join(filteredRight, condition = Some(joinCondition)).analyze

  comparePlans(optimized, expected)
}
1362
+
1363
test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_CNF_NODE_COUNT.key}=0") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Join condition shared by the input plan and every expected plan.
  val keysMatch = "x.b".attr === "y.b".attr
  val mixedBranch = ("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)
  val rightOnlyBranch = ("y.a".attr > 2) && ("y.c".attr < 1)
  val joinCondition = keysMatch && (mixedBranch || rightOnlyBranch)

  val originalQuery = x.join(y, condition = Some(joinCondition))

  // Run the optimizer once with the limit set to 0 and once with it set to 10.
  Seq(0, 10).foreach { maxNodeCount =>
    withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxNodeCount.toString) {
      val optimized = Optimize.execute(originalQuery.analyze)

      // The left input is expected to stay unfiltered for both settings.
      val expectedLeft = testRelation.subquery('x)
      val expectedRight = maxNodeCount match {
        // With the limit at 0 no filter is expected on the right input.
        case 0 => testRelation.subquery('y)
        case _ => testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)
      }
      val expected =
        expectedLeft.join(expectedRight, condition = Some(joinCondition)).analyze

      comparePlans(optimized, expected)
    }
  }
}
1233
1390
}
0 commit comments