From 054594f692e58a203e88b40a201475d9fc153011 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Fri, 5 Jun 2020 00:09:38 -0700
Subject: [PATCH 01/18] add rule PushCNFPredicateThroughJoin

---
 .../PushCNFPredicateThroughJoin.scala         | 139 ++++++++++++++++++
 .../apache/spark/sql/internal/SQLConf.scala   |  15 ++
 2 files changed, 154 insertions(+)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
new file mode 100644
index 0000000000000..bd19d63ed9d1c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, Not, Or, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Try converting the join condition to conjunctive normal form (CNF) so that more predicates
+ * can be pushed down.
+ * To avoid expanding the join condition, it is kept in its original form even when predicate
+ * pushdown happens.
+ */
+object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
+  /**
+   * Convert an expression into conjunctive normal form.
+   * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
+   * CNF can explode exponentially in the size of the input expression when converting Or clauses.
+   * Use the configuration MAX_CNF_NODE_COUNT to prevent such cases.
+   *
+   * @param condition to be converted into CNF.
+   * @return If the number of expressions exceeds the threshold on converting Or, return Seq.empty.
+   *         If the conversion repeatedly expands nondeterministic expressions, return Seq.empty.
+   *         Otherwise, return the converted result as a sequence of disjunctive expressions.
+   */
+  protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
+    val postOrderNodes = postOrderTraversal(condition)
+    val resultStack = new scala.collection.mutable.Stack[Seq[Expression]]
+    val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
+    // Bottom up approach to get CNF of sub-expressions
+    while (postOrderNodes.nonEmpty) {
+      val cnf = postOrderNodes.pop() match {
+        case _: And =>
+          val right: Seq[Expression] = resultStack.pop()
+          val left: Seq[Expression] = resultStack.pop()
+          left ++ right
+        case _: Or =>
+          val right: Seq[Expression] = resultStack.pop()
+          val left: Seq[Expression] = resultStack.pop()
+          // Stop the loop whenever the result exceeds the `maxCnfNodeCount`
+          if (left.size * right.size > maxCnfNodeCount) {
+            Seq.empty
+          } else {
+            for {x <- left; y <- right} yield Or(x, y)
+          }
+        case other => other :: Nil
+      }
+      if (cnf.isEmpty) {
+        return Seq.empty
+      }
+      resultStack.push(cnf)
+    }
+    assert(resultStack.length == 1,
+      s"Fail to convert expression ${condition} to conjunctive normal form")
+    resultStack.top
+  }
+
+  /**
+   * Iterative post order traversal over a binary tree built by And/Or clauses.
+   * @param condition to be traversed as binary tree
+   * @return sub-expressions in post order traversal as an Array.
+   *         The first element of result Array is the leftmost node.
+   */
+  private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
+    val stack = new mutable.Stack[Expression]
+    val result = new mutable.Stack[Expression]
+    stack.push(condition)
+    while (stack.nonEmpty) {
+      val node = stack.pop()
+      node match {
+        case Not(a And b) => stack.push(Or(Not(a), Not(b)))
+        case Not(a Or b) => stack.push(And(Not(a), Not(b)))
+        case Not(Not(a)) => stack.push(a)
+        case a And b =>
+          result.push(node)
+          stack.push(a)
+          stack.push(b)
+        case a Or b =>
+          result.push(node)
+          stack.push(a)
+          stack.push(b)
+        case _ =>
+          result.push(node)
+      }
+    }
+    result
+  }
+
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
+      val predicates = conjunctiveNormalForm(joinCondition)
+      if (predicates.isEmpty || predicates.size > SQLConf.get.maxCnfNodeCount) {
+        j
+      } else {
+        val pushDownCandidates = predicates.filter(_.deterministic)
+        val leftFilterConditions = pushDownCandidates.filter(_.references.subsetOf(left.outputSet))
+        val rightFilterConditions =
+          pushDownCandidates.filter(_.references.subsetOf(right.outputSet))
+
+        val newLeft =
+          leftFilterConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+        val newRight =
+          rightFilterConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+
+        joinType match {
+          case _: InnerLike | LeftSemi =>
+            Join(newLeft, newRight, joinType, Some(joinCondition), hint)
+          case RightOuter =>
+            Join(newLeft, right, RightOuter, Some(joinCondition), hint)
+          case LeftOuter | LeftAnti | ExistenceJoin(_) =>
+            Join(left, newRight, joinType, Some(joinCondition), hint)
+          case FullOuter => j
+          case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
+          case UsingJoin(_, _) => sys.error("Untransformed Using join node")
+        }
+      }
+  }
+}
\ No newline at end of file
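The Or branch above is where the exponential blow-up can occur, which is exactly what the maxCnfNodeCount guard bounds. A minimal standalone sketch of the same expansion, written against a toy expression ADT rather than Catalyst's Expression types (all names below are illustrative, not Spark API), and recursive where the rule deliberately uses an explicit two-stack post-order traversal to avoid deep recursion:

  sealed trait Expr
  case class Leaf(name: String) extends Expr
  case class Conj(l: Expr, r: Expr) extends Expr   // logical AND
  case class Disj(l: Expr, r: Expr) extends Expr   // logical OR
  case class Neg(e: Expr) extends Expr             // logical NOT

  // Returns the CNF clauses of `e` (each clause is a disjunction), or None once
  // the clause count would exceed `maxClauses`, mirroring maxCnfNodeCount above.
  def toCnf(e: Expr, maxClauses: Int): Option[Seq[Expr]] = e match {
    case Conj(l, r) =>
      for { lc <- toCnf(l, maxClauses); rc <- toCnf(r, maxClauses) } yield lc ++ rc
    case Disj(l, r) =>
      for {
        lc <- toCnf(l, maxClauses)
        rc <- toCnf(r, maxClauses)
        if lc.size * rc.size <= maxClauses  // bail out instead of exploding
      } yield for { x <- lc; y <- rc } yield Disj(x, y)
    case Neg(Conj(l, r)) => toCnf(Disj(Neg(l), Neg(r)), maxClauses) // De Morgan
    case Neg(Disj(l, r)) => toCnf(Conj(Neg(l), Neg(r)), maxClauses) // De Morgan
    case Neg(Neg(x))     => toCnf(x, maxClauses)                    // double negation
    case other           => Some(Seq(other))
  }

  // (a && b) || (c && d) expands to (a || c), (a || d), (b || c), (b || d):
  val cnf = toCnf(Disj(Conj(Leaf("a"), Leaf("b")), Conj(Leaf("c"), Leaf("d"))), 128)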
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 189740e313207..b64c89046a243 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -545,6 +545,19 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val MAX_CNF_NODE_COUNT =
+    buildConf("spark.sql.optimizer.maxCNFNodeCount")
+      .internal()
+      .doc("Specifies the maximum allowable number of conjuncts in the result of CNF " +
+        "conversion. If the conversion exceeds the threshold, None is returned. " +
+        "For example, CNF conversion of (a && b) || (c && d) generates " +
+        "four conjuncts (a || c) && (a || d) && (b || c) && (b || d).")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ >= 0,
+        "The maximum number of conjuncts allowed in CNF conversion must be non-negative.")
+      .createWithDefault(10)
+
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -2874,6 +2887,8 @@ class SQLConf extends Serializable with Logging {
 
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
+  def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT)
+
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
   def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
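Since the new entry is internal, it has no documented public surface; if one needs to tune or disable the rewrite for a session, it would be set through its raw key. An illustrative snippet (not from the patch; setting it to 0 disables the rewrite, as the test added in the next patch exercises):

  // Allow a larger expansion before the conversion gives up.
  spark.conf.set("spark.sql.optimizer.maxCNFNodeCount", "128")
  // Effectively disable the CNF rewrite.
  spark.conf.set("spark.sql.optimizer.maxCNFNodeCount", "0")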
+ Batch("Push CNF predicate through join", Once, + PushCNFPredicateThroughJoin) :: Nil } val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 70e29dca46e9e..64c093b5ed6c5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -25,12 +25,17 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, IntegerType} import org.apache.spark.unsafe.types.CalendarInterval class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { + + override protected val blacklistedOnceBatches: Set[String] = + Set("Push predicate through join by CNF") + val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: @@ -39,7 +44,9 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughNonJoin, BooleanSimplification, PushPredicateThroughJoin, - CollapseProject) :: Nil + CollapseProject) :: + Batch("Push predicate through join by CNF", Once, + PushCNFPredicateThroughJoin) :: Nil } val attrA = 'a.int @@ -1230,4 +1237,154 @@ class FilterPushdownSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), expected) } + + test("inner join: rewrite filter predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x) + val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join: rewrite join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a > 3 || 'a > 1).subquery('x) + val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && Not(("x.a".attr > 3) + && ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = 
Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a <= 3 || 'a >= 2).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13))) + && (("x.a".attr <= 1) || ("y.a".attr <= 11)))) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("left join: rewrite join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.subquery('x) + val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val correctAnswer = + left.join(right, joinType = LeftOuter, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("right join: rewrite join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a > 3 || 'a > 1).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + left.join(right, joinType = RightOuter, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.subquery('x) + val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y) + val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze + + comparePlans(optimized, correctAnswer) + } + + test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_CNF_NODE_COUNT.key}=0") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))) + } + + Seq(0, 10).foreach { depth => + withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> depth.toString) { + val optimized = Optimize.execute(originalQuery.analyze) + val (left, right) = if (depth == 0) { + (testRelation.subquery('x), testRelation.subquery('y)) + } else { + (testRelation.subquery('x), + testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)) + } + val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr + && ((("x.a".attr > 
3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze + + comparePlans(optimized, correctAnswer) + } + } + } } From 3497b3c63848236274050a78b6d7bf19b314772d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 5 Jun 2020 00:55:59 -0700 Subject: [PATCH 03/18] fix test failure --- .../PushCNFPredicateThroughJoin.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../optimizer/FilterPushdownSuite.scala | 32 ++++--------------- 3 files changed, 8 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala index bd19d63ed9d1c..c0be070b9dbab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala @@ -136,4 +136,4 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe } } } -} \ No newline at end of file +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b64c89046a243..e2612f09ea173 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -556,7 +556,7 @@ object SQLConf { .intConf .checkValue(_ >= 0, "The depth of the maximum rewriting conjunction normal form must be positive.") - .createWithDefault(10) + .createWithDefault(20) val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 64c093b5ed6c5..1220278bad37c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -34,7 +34,7 @@ class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { override protected val blacklistedOnceBatches: Set[String] = - Set("Push predicate through join by CNF") + Set("Push CNF predicate through join") val batches = Batch("Subqueries", Once, @@ -45,7 +45,7 @@ class FilterPushdownSuite extends PlanTest { BooleanSimplification, PushPredicateThroughJoin, CollapseProject) :: - Batch("Push predicate through join by CNF", Once, + Batch("Push CNF predicate through join", Once, PushCNFPredicateThroughJoin) :: Nil } @@ -1340,26 +1340,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - - val originalQuery = { - x.join(y, condition = Some(("x.b".attr === "y.b".attr) - && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) - || (("y.a".attr > 2) && ("y.c".attr < 1))))) - } - - val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.subquery('x) - val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y) - val correctAnswer = left.join(right, condition = Some("x.b".attr === 
"y.b".attr - && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) - || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze - - comparePlans(optimized, correctAnswer) - } - test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_CNF_NODE_COUNT.key}=0") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -1370,14 +1350,14 @@ class FilterPushdownSuite extends PlanTest { || (("y.a".attr > 2) && ("y.c".attr < 1))))) } - Seq(0, 10).foreach { depth => - withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> depth.toString) { + Seq(0, 10).foreach { count => + withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> count.toString) { val optimized = Optimize.execute(originalQuery.analyze) - val (left, right) = if (depth == 0) { + val (left, right) = if (count == 0) { (testRelation.subquery('x), testRelation.subquery('y)) } else { (testRelation.subquery('x), - testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)) + testRelation.where(('c <= 5 || 'c < 1) && ('c <=5 || 'a > 2)).subquery('y)) } val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) From ada4135b271ceaf526b04cac432c12d19c2a8e3c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 5 Jun 2020 23:09:40 +0800 Subject: [PATCH 04/18] Add test rewrite complex join predicates to conjunctive normal form --- .../optimizer/FilterPushdownSuite.scala | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 1220278bad37c..06ddfeb02f898 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -1279,6 +1279,26 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("inner join: rewrite complex join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val joinCondition = (("x.b".attr === "y.b".attr) + && ((("x.a".attr === 5) && ("y.a".attr >= 2) && ("y.a".attr <= 3)) + || (("x.a".attr === 2) && ("y.a".attr >= 1) && ("y.a".attr <= 14)) + || (("x.a".attr === 1) && ("y.a".attr >= 9) && ("y.a".attr <= 27)))) + + val originalQuery = x.join(y, condition = Some(joinCondition)) + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where( + ('a === 5 || 'a === 2 || 'a === 1)).subquery('x) + val right = testRelation.where( + ('a >= 2 && 'a <= 3) || ('a >= 1 && 'a <= 14) || ('a >= 9 && 'a <= 27)).subquery('y) + val correctAnswer = left.join(right, condition = Some(joinCondition)).analyze + + comparePlans(optimized, correctAnswer) + } + test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) From 7e4b0190a5c5a0aab47134b31897cde730da8685 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 5 Jun 2020 11:58:33 -0700 Subject: [PATCH 05/18] increase threshold and test case --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../sql/catalyst/optimizer/FilterPushdownSuite.scala | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
From 7e4b0190a5c5a0aab47134b31897cde730da8685 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Fri, 5 Jun 2020 11:58:33 -0700
Subject: [PATCH 05/18] increase threshold and test case

---
 .../scala/org/apache/spark/sql/internal/SQLConf.scala       | 2 +-
 .../sql/catalyst/optimizer/FilterPushdownSuite.scala        | 9 ++++++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e2612f09ea173..55204b2346efd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -556,7 +556,7 @@ object SQLConf {
       .intConf
       .checkValue(_ >= 0,
         "The maximum number of conjuncts allowed in CNF conversion must be non-negative.")
-      .createWithDefault(20)
+      .createWithDefault(256)
 
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 06ddfeb02f898..63220e642dec6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -1293,7 +1293,14 @@ class FilterPushdownSuite {
     val left = testRelation.where(
       ('a === 5 || 'a === 2 || 'a === 1)).subquery('x)
     val right = testRelation.where(
-      ('a >= 2 && 'a <= 3) || ('a >= 1 && 'a <= 14) || ('a >= 9 && 'a <= 27)).subquery('y)
+      ('a >= 2 || 'a >= 1 || 'a >= 9) &&
+        ('a >= 2 || 'a >= 1 || 'a <= 27) &&
+        ('a >= 2 || 'a <= 14 || 'a >= 9) &&
+        ('a >= 2 || 'a <= 14 || 'a <= 27) &&
+        ('a <= 3 || 'a >= 1 || 'a >= 9) &&
+        ('a <= 3 || 'a >= 1 || 'a <= 27) &&
+        ('a <= 3 || 'a <= 14 || 'a >= 9) &&
+        ('a <= 3 || 'a <= 14 || 'a <= 27)).subquery('y)
     val correctAnswer = left.join(right, condition = Some(joinCondition)).analyze
 
     comparePlans(optimized, correctAnswer)

From 0c92c1cfac98981d1f78bb5691f46d9276251d88 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Fri, 5 Jun 2020 19:28:16 -0700
Subject: [PATCH 06/18] fix test case; reduce threshold default value

---
 .../PushCNFPredicateThroughJoin.scala         | 10 +++++--
 .../apache/spark/sql/internal/SQLConf.scala   |  2 +-
 .../optimizer/FilterPushdownSuite.scala       | 30 +++++++++++++------
 3 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
index c0be070b9dbab..ef155d0f96522 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
@@ -55,8 +55,11 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
           val left: Seq[Expression] = resultStack.pop()
           left ++ right
         case _: Or =>
-          val right: Seq[Expression] = resultStack.pop()
-          val left: Seq[Expression] = resultStack.pop()
+          // For each side, there is no need to expand predicates with the same references.
+          // So here we aggregate predicates with the same references into one single predicate,
+          // reducing the size of the pushed-down predicates and the corresponding codegen.
+          val right = aggregateExpressionsOfSameReference(resultStack.pop())
+          val left = aggregateExpressionsOfSameReference(resultStack.pop())
           // Stop the loop whenever the result exceeds the `maxCnfNodeCount`
           if (left.size * right.size > maxCnfNodeCount) {
             Seq.empty
@@ -75,6 +78,9 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
     resultStack.top
   }
 
+  private def aggregateExpressionsOfSameReference(expressions: Seq[Expression]): Seq[Expression] = {
+    expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
+  }
   /**
    * Iterative post order traversal over a binary tree built by And/Or clauses.
    * @param condition to be traversed as binary tree
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 55204b2346efd..37bc116a72dfa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -556,7 +556,7 @@ object SQLConf {
       .intConf
       .checkValue(_ >= 0,
         "The maximum number of conjuncts allowed in CNF conversion must be non-negative.")
-      .createWithDefault(256)
+      .createWithDefault(128)
 
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 63220e642dec6..aad8fdf395fd5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -1293,14 +1293,7 @@ class FilterPushdownSuite {
     val left = testRelation.where(
       ('a === 5 || 'a === 2 || 'a === 1)).subquery('x)
     val right = testRelation.where(
-      ('a >= 2 || 'a >= 1 || 'a >= 9) &&
-        ('a >= 2 || 'a >= 1 || 'a <= 27) &&
-        ('a >= 2 || 'a <= 14 || 'a >= 9) &&
-        ('a >= 2 || 'a <= 14 || 'a <= 27) &&
-        ('a <= 3 || 'a >= 1 || 'a >= 9) &&
-        ('a <= 3 || 'a >= 1 || 'a <= 27) &&
-        ('a <= 3 || 'a <= 14 || 'a >= 9) &&
-        ('a <= 3 || 'a <= 14 || 'a <= 27)).subquery('y)
+      ('a >= 2 && 'a <= 3) || ('a >= 1 && 'a <= 14) || ('a >= 9 && 'a <= 27)).subquery('y)
     val correctAnswer = left.join(right, condition = Some(joinCondition)).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -1367,6 +1360,25 @@ class FilterPushdownSuite {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, condition = Some(("x.b".attr === "y.b".attr) && ((("x.a".attr > 3) &&
+        ("x.a".attr < 13) && ("y.c".attr <= 5)) || (("y.a".attr > 2) && ("y.c".attr < 1)))))
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.subquery('x)
+    val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)
+    val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr &&
+      ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) ||
+      (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_CNF_NODE_COUNT.key}=0") {
     val x = testRelation.subquery('x)
     val y = testRelation.subquery('y)
@@ -1384,7 +1396,7 @@ class FilterPushdownSuite {
         (testRelation.subquery('x), testRelation.subquery('y))
       } else {
         (testRelation.subquery('x),
-          testRelation.where(('c <= 5 || 'c < 1) && ('c <= 5 || 'a > 2)).subquery('y))
+          testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y))
       }
       val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
         && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
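Patch 06's grouping idea is worth spelling out: before distributing Or, each side's clauses that reference the same qualifiers are folded into a single predicate, which shrinks the left.size * right.size product. A self-contained sketch of that step, with a toy Pred type standing in for Catalyst's Expression (the real code groups by _.references.map(_.qualifier)):

  case class Pred(qualifiers: Set[String], sql: String)

  def groupByQualifiers(preds: Seq[Pred]): Seq[Pred] =
    preds.groupBy(_.qualifiers)
      .map { case (q, ps) => Pred(q, ps.map(_.sql).mkString("(", " AND ", ")")) }
      .toSeq

  // One Or branch (x.a > 3 AND x.a < 13 AND y.c <= 5) collapses from three
  // predicates to two (one per table), so distributing Or over a k-clause
  // branch now yields 2 * k clauses instead of 3 * k.
  val branch = Seq(
    Pred(Set("x"), "x.a > 3"), Pred(Set("x"), "x.a < 13"), Pred(Set("y"), "y.c <= 5"))
  val grouped = groupByQualifiers(branch)  // two grouped predicates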
From 7853f676714419267aea3be7e7903013fe951615 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Sun, 7 Jun 2020 22:36:29 -0700
Subject: [PATCH 07/18] address comments and add test cases

---
 .../sql/catalyst/expressions/predicates.scala |  83 ++++++++++++
 .../PushCNFPredicateThroughJoin.scala         |  85 +-----------
 .../ConjunctiveNormalFormPredicateSuite.scala | 128 ++++++++++++++++++
 3 files changed, 212 insertions(+), 84 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2c4f41f98ac20..589a6663bd420 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.immutable.TreeSet
+import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.InternalRow
@@ -198,6 +199,88 @@ trait PredicateHelper {
     case e: Unevaluable => false
     case e => e.children.forall(canEvaluateWithinJoin)
   }
+
+  /**
+   * Convert an expression into conjunctive normal form.
+   * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
+   * CNF can explode exponentially in the size of the input expression when converting Or clauses.
+   * Use the configuration MAX_CNF_NODE_COUNT to prevent such cases.
+   *
+   * @param condition to be converted into CNF.
+   * @return If the number of expressions exceeds the threshold on converting Or, return Seq.empty.
+   *         If the conversion repeatedly expands nondeterministic expressions, return Seq.empty.
+   *         Otherwise, return the converted result as a sequence of disjunctive expressions.
+   */
+  def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
+    val postOrderNodes = postOrderTraversal(condition)
+    val resultStack = new mutable.Stack[Seq[Expression]]
+    val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
+    // Bottom up approach to get CNF of sub-expressions
+    while (postOrderNodes.nonEmpty) {
+      val cnf = postOrderNodes.pop() match {
+        case _: And =>
+          val right: Seq[Expression] = resultStack.pop()
+          val left: Seq[Expression] = resultStack.pop()
+          left ++ right
+        case _: Or =>
+          // For each side, there is no need to expand predicates with the same references.
+          // So here we aggregate predicates with the same references into one single predicate,
+          // reducing the size of the pushed-down predicates and the corresponding codegen.
+          val right = aggregateExpressionsOfSameQualifiers(resultStack.pop())
+          val left = aggregateExpressionsOfSameQualifiers(resultStack.pop())
+          // Stop the loop whenever the result exceeds the `maxCnfNodeCount`
+          if (left.size * right.size > maxCnfNodeCount) {
+            Seq.empty
+          } else {
+            for {x <- left; y <- right} yield Or(x, y)
+          }
+        case other => other :: Nil
+      }
+      if (cnf.isEmpty) {
+        return Seq.empty
+      }
+      resultStack.push(cnf)
+    }
+    assert(resultStack.length == 1,
+      s"Fail to convert expression ${condition} to conjunctive normal form")
+    resultStack.top
+  }
+
+  private def aggregateExpressionsOfSameQualifiers(
+      expressions: Seq[Expression]): Seq[Expression] = {
+    expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
+  }
+
+  /**
+   * Iterative post order traversal over a binary tree built by And/Or clauses.
+   * @param condition to be traversed as binary tree
+   * @return sub-expressions in post order traversal as an Array.
+   *         The first element of result Array is the leftmost node.
+   */
+  private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
+    val stack = new mutable.Stack[Expression]
+    val result = new mutable.Stack[Expression]
+    stack.push(condition)
+    while (stack.nonEmpty) {
+      val node = stack.pop()
+      node match {
+        case Not(a And b) => stack.push(Or(Not(a), Not(b)))
+        case Not(a Or b) => stack.push(And(Not(a), Not(b)))
+        case Not(Not(a)) => stack.push(a)
+        case a And b =>
+          result.push(node)
+          stack.push(a)
+          stack.push(b)
+        case a Or b =>
+          result.push(node)
+          stack.push(a)
+          stack.push(b)
+        case _ =>
+          result.push(node)
+      }
+    }
+    result
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
index ef155d0f96522..48b460d1e6cf4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.collection.mutable
-
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, Not, Or, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -32,87 +30,6 @@ import org.apache.spark.sql.internal.SQLConf
  * pushdown happens.
  */
 object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
-  /**
-   * Convert an expression into conjunctive normal form.
-   * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
-   * CNF can explode exponentially in the size of the input expression when converting Or clauses.
-   * Use the configuration MAX_CNF_NODE_COUNT to prevent such cases.
-   *
-   * @param condition to be converted into CNF.
-   * @return If the number of expressions exceeds the threshold on converting Or, return Seq.empty.
-   *         If the conversion repeatedly expands nondeterministic expressions, return Seq.empty.
-   *         Otherwise, return the converted result as a sequence of disjunctive expressions.
- */ - protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = { - val postOrderNodes = postOrderTraversal(condition) - val resultStack = new scala.collection.mutable.Stack[Seq[Expression]] - val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount - // Bottom up approach to get CNF of sub-expressions - while (postOrderNodes.nonEmpty) { - val cnf = postOrderNodes.pop() match { - case _: And => - val right: Seq[Expression] = resultStack.pop() - val left: Seq[Expression] = resultStack.pop() - left ++ right - case _: Or => - // For each side, there is no need to expand predicates of the same references. - // So here we can aggregate predicates of the same references as one single predicate, - // for reducing the size of pushed down predicates and corresponding codegen. - val right = aggregateExpressionsOfSameReference(resultStack.pop()) - val left = aggregateExpressionsOfSameReference(resultStack.pop()) - // Stop the loop whenever the result exceeds the `maxCnfNodeCount` - if (left.size * right.size > maxCnfNodeCount) { - Seq.empty - } else { - for {x <- left; y <- right} yield Or(x, y) - } - case other => other :: Nil - } - if (cnf.isEmpty) { - return Seq.empty - } - resultStack.push(cnf) - } - assert(resultStack.length == 1, - s"Fail to convert expression ${condition} to conjunctive normal form") - resultStack.top - } - - private def aggregateExpressionsOfSameReference(expressions: Seq[Expression]): Seq[Expression] = { - expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq - } - /** - * Iterative post order traversal over a binary tree built by And/Or clauses. - * @param condition to be traversed as binary tree - * @return sub-expressions in post order traversal as an Array. - * The first element of result Array is the leftmost node. - */ - private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = { - val stack = new mutable.Stack[Expression] - val result = new mutable.Stack[Expression] - stack.push(condition) - while (stack.nonEmpty) { - val node = stack.pop() - node match { - case Not(a And b) => stack.push(Or(Not(a), Not(b))) - case Not(a Or b) => stack.push(And(Not(a), Not(b))) - case Not(Not(a)) => stack.push(a) - case a And b => - result.push(node) - stack.push(a) - stack.push(b) - case a Or b => - result.push(node) - stack.push(a) - stack.push(b) - case _ => - result.push(node) - } - } - result - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { case j @ Join(left, right, joinType, Some(joinCondition), hint) => val predicates = conjunctiveNormalForm(joinCondition) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala new file mode 100644 index 0000000000000..2ae63d9d144dc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.BooleanType
+
+class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHelper with PlanTest {
+  private val a = AttributeReference("A", BooleanType)(exprId = ExprId(1)).withQualifier(Seq("ta"))
+  private val b = AttributeReference("B", BooleanType)(exprId = ExprId(2)).withQualifier(Seq("tb"))
+  private val c = AttributeReference("C", BooleanType)(exprId = ExprId(3)).withQualifier(Seq("tc"))
+  private val d = AttributeReference("D", BooleanType)(exprId = ExprId(4)).withQualifier(Seq("td"))
+  private val e = AttributeReference("E", BooleanType)(exprId = ExprId(5)).withQualifier(Seq("te"))
+  private val f = AttributeReference("F", BooleanType)(exprId = ExprId(6)).withQualifier(Seq("tf"))
+  private val g = AttributeReference("C", BooleanType)(exprId = ExprId(7)).withQualifier(Seq("tg"))
+  private val h = AttributeReference("D", BooleanType)(exprId = ExprId(8)).withQualifier(Seq("th"))
+  private val i = AttributeReference("E", BooleanType)(exprId = ExprId(9)).withQualifier(Seq("ti"))
+  private val j = AttributeReference("F", BooleanType)(exprId = ExprId(10)).withQualifier(Seq("tj"))
+  private val a1 =
+    AttributeReference("a1", BooleanType)(exprId = ExprId(11)).withQualifier(Seq("ta"))
+  private val a2 =
+    AttributeReference("a2", BooleanType)(exprId = ExprId(12)).withQualifier(Seq("ta"))
+  private val b1 =
+    AttributeReference("b1", BooleanType)(exprId = ExprId(13)).withQualifier(Seq("tb"))
+
+  // Check CNF conversion with expected expression, assuming the input has non-empty result.
+  private def checkCondition(input: Expression, expected: Expression): Unit = {
+    val cnf = conjunctiveNormalForm(input)
+    assert(cnf.nonEmpty)
+    val result = cnf.reduceLeft(And)
+    assert(result.semanticEquals(expected))
+  }
+
+  test("Keep non-predicated expressions") {
+    checkCondition(a, a)
+    checkCondition(Literal(1), Literal(1))
+  }
+
+  test("Conversion of Not") {
+    checkCondition(!a, !a)
+    checkCondition(!(!a), a)
+    checkCondition(!(!(a && b)), a && b)
+    checkCondition(!(!(a || b)), a || b)
+    checkCondition(!(a || b), !a && !b)
+    checkCondition(!(a && b), !a || !b)
+  }
+
+  test("Conversion of And") {
+    checkCondition(a && b, a && b)
+    checkCondition(a && b && c, a && b && c)
+    checkCondition(a && (b || c), a && (b || c))
+    checkCondition((a || b) && c, (a || b) && c)
+    checkCondition(a && b && c && d, a && b && c && d)
+  }
+
+  test("Conversion of Or") {
+    checkCondition(a || b, a || b)
+    checkCondition(a || b || c, a || b || c)
+    checkCondition(a || b || c || d, a || b || c || d)
+    checkCondition((a && b) || c, (a || c) && (b || c))
+    checkCondition((a && b) || (c && d), (a || c) && (a || d) && (b || c) && (b || d))
+  }
+
+  test("More complex cases") {
+    checkCondition(a && !(b || c), a && !b && !c)
+    checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d))
+    checkCondition(a || b || c && d, (a || b || c) && (a || b || d))
+    checkCondition(a || (b && c || d), (a || b || d) && (a || c || d))
+    checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e))
+    checkCondition(((a && b) || c) || (d || e), (a || c || d || e) && (b || c || d || e))
+
+    checkCondition(
+      (a && b && c) || (d && e && f),
+      (a || d) && (a || e) && (a || f) && (b || d) && (b || e) && (b || f) &&
+        (c || d) && (c || e) && (c || f)
+    )
+  }
+
+  test("Aggregate predicate of same qualifiers to avoid expanding") {
+    checkCondition(((a && b && a1) || c), ((a && a1) || c) && (b || c))
+    checkCondition(((a && a1 && b) || c), ((a && a1) || c) && (b || c))
+    checkCondition(((b && d && a && a1) || c), ((a && a1) || c) && (b || c) && (d || c))
+    checkCondition(((b && a2 && d && a && a1) || c), ((a2 && a && a1) || c) && (b || c) && (d || c))
+    checkCondition(((b && d && a && a1 && b1) || c),
+      ((a && a1) || c) && ((b && b1) || c) && (d || c))
+    checkCondition((a && a1) || (b && b1), (a && a1) || (b && b1))
+    checkCondition((a && a1 && c) || (b && b1), ((a && a1) || (b && b1)) && (c || (b && b1)))
+  }
+
+  test("Return None when exceeding MAX_CNF_NODE_COUNT") {
+    // The following expression contains 36 conjunctive sub-expressions in CNF
+    val input = (a && b && c) || (d && e && f) || (g && h && i && j)
+    // The following expression contains 9 conjunctive sub-expressions in CNF
+    val input2 = (a && b && c) || (d && e && f)
+    Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
+      withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
+        if (maxCount < 36) {
+          assert(conjunctiveNormalForm(input).isEmpty)
+        } else {
+          assert(conjunctiveNormalForm(input).nonEmpty)
+        }
+        if (maxCount < 9) {
+          assert(conjunctiveNormalForm(input2).isEmpty)
+        } else {
+          assert(conjunctiveNormalForm(input2).nonEmpty)
+        }
+      }
+    }
+  }
+}
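The clause counts asserted in the last test follow from the product rule for CNF: a disjunction of conjunctions with sizes n1, n2, ... expands to n1 * n2 * ... clauses, and since every attribute in this suite carries a distinct qualifier, the qualifier grouping cannot merge anything. A quick check of the two comments above (plain arithmetic, not Spark code):

  val clauses1 = 3 * 3 * 4  // (a&&b&&c) || (d&&e&&f) || (g&&h&&i&&j) => 36 clauses
  val clauses2 = 3 * 3      // (a&&b&&c) || (d&&e&&f)                 => 9 clauses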
From 84e89deac51f4e49cecba98c24686e81630942da Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Mon, 8 Jun 2020 00:11:35 -0700
Subject: [PATCH 08/18] revise

---
 .../expressions/ConjunctiveNormalFormPredicateSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
index 2ae63d9d144dc..a089273159afd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
@@ -105,7 +105,7 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe
     checkCondition((a && a1 && c) || (b && b1), ((a && a1) || (b && b1)) && (c || (b && b1)))
   }
 
-  test("Return None when exceeding MAX_CNF_NODE_COUNT") {
+  test("Return Seq.empty when exceeding MAX_CNF_NODE_COUNT") {
     // The following expression contains 36 conjunctive sub-expressions in CNF
     val input = (a && b && c) || (d && e && f) || (g && h && i && j)
     // The following expression contains 9 conjunctive sub-expressions in CNF

From 0af4c48a31eb3ba2e49877f533d72c39e470913c Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Tue, 9 Jun 2020 02:32:27 -0700
Subject: [PATCH 09/18] lazy val

---
 .../catalyst/optimizer/PushCNFPredicateThroughJoin.scala | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
index 48b460d1e6cf4..1ef18703397b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
@@ -37,13 +37,14 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
         j
       } else {
         val pushDownCandidates = predicates.filter(_.deterministic)
-        val leftFilterConditions = pushDownCandidates.filter(_.references.subsetOf(left.outputSet))
-        val rightFilterConditions =
+        lazy val leftFilterConditions =
+          pushDownCandidates.filter(_.references.subsetOf(left.outputSet))
+        lazy val rightFilterConditions =
           pushDownCandidates.filter(_.references.subsetOf(right.outputSet))
 
-        val newLeft =
+        lazy val newLeft =
           leftFilterConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
-        val newRight =
+        lazy val newRight =
           rightFilterConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
 
         joinType match {

From 95ee45e19446c9744277a76da6797cd0f5f5042c Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Tue, 9 Jun 2020 13:18:16 -0700
Subject: [PATCH 10/18] update doc

---
 .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 37bc116a72dfa..33f40b47d072b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -549,7 +549,7 @@ object SQLConf {
     buildConf("spark.sql.optimizer.maxCNFNodeCount")
       .internal()
       .doc("Specifies the maximum allowable number of conjuncts in the result of CNF " +
-        "conversion. If the conversion exceeds the threshold, None is returned. " +
+        "conversion. If the conversion exceeds the threshold, an empty sequence is returned.
" + "For example, CNF conversion of (a && b) || (c && d) generates " + "four conjuncts (a || c) && (a || d) && (b || c) && (b || d).") .version("3.1.0") From 6bf474742c5e2c7017d9e376b886d814a338ebaf Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 9 Jun 2020 15:16:07 -0700 Subject: [PATCH 11/18] remove assert --- .../org/apache/spark/sql/catalyst/expressions/predicates.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 589a6663bd420..42c82d36d4c36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -241,8 +241,6 @@ trait PredicateHelper { } resultStack.push(cnf) } - assert(resultStack.length == 1, - s"Fail to convert expression ${condition} to conjunctive normal form") resultStack.top } From fa03b00865f5cf407336f17fc2c543c437219235 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 9 Jun 2020 15:45:12 -0700 Subject: [PATCH 12/18] add back warning --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 42c82d36d4c36..8c703e8928baf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.immutable.TreeSet import scala.collection.mutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -96,7 +97,7 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr } } -trait PredicateHelper { +trait PredicateHelper extends Logging { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { condition match { case And(cond1, cond2) => @@ -239,6 +240,10 @@ trait PredicateHelper { if (cnf.isEmpty) { return Seq.empty } + if (resultStack.length != 1) { + logWarning("The length of CNF conversion result stack is supposed to be 1. 
There might " + + "be something wrong with CNF conversion.") + } resultStack.push(cnf) } resultStack.top From c225f7443a96368edd59457f469a06193535011f Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 9 Jun 2020 23:51:31 -0700 Subject: [PATCH 13/18] address comments --- .../sql/catalyst/expressions/predicates.scala | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 8c703e8928baf..61d65f986f65a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -243,6 +243,7 @@ trait PredicateHelper extends Logging { if (resultStack.length != 1) { logWarning("The length of CNF conversion result stack is supposed to be 1. There might " + "be something wrong with CNF conversion.") + return Seq.empty } resultStack.push(cnf) } @@ -255,10 +256,19 @@ trait PredicateHelper extends Logging { } /** - * Iterative post order traversal over a binary tree built by And/Or clauses. + * Iterative post order traversal over a binary tree built by And/Or clauses with two stacks. + * For example, a condition `(a And b) Or c`, the postorder traversal is + * (`a`,`b`, `And`, `c`, `Or`). + * Following is the complete algorithm. After step 2, we get the postorder traversal in + * the second stack. + * 1. Push root to first stack. + * 2. Loop while first stack is not empty + * 2.1 Pop a node from first stack and push it to second stack + * 2.2 Push the children of the popped node to first stack + * * @param condition to be traversed as binary tree - * @return sub-expressions in post order traversal as an Array. - * The first element of result Array is the leftmost node. + * @return sub-expressions in post order traversal as a stack. + * The first element of result stack is the leftmost node. */ private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = { val stack = new mutable.Stack[Expression] From 296068cb1f3f47e3a897387a518a1c8a2b83175f Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 10 Jun 2020 00:10:49 -0700 Subject: [PATCH 14/18] address comments --- .../sql/catalyst/expressions/predicates.scala | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 61d65f986f65a..bd585a3b1c97d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -227,30 +227,27 @@ trait PredicateHelper extends Logging { // For each side, there is no need to expand predicates of the same references. // So here we can aggregate predicates of the same references as one single predicate, // for reducing the size of pushed down predicates and corresponding codegen. 
-          val right = aggregateExpressionsOfSameQualifiers(resultStack.pop())
-          val left = aggregateExpressionsOfSameQualifiers(resultStack.pop())
+          val right = groupExpressionsByQualifier(resultStack.pop())
+          val left = groupExpressionsByQualifier(resultStack.pop())
           // Stop the loop whenever the result exceeds the `maxCnfNodeCount`
           if (left.size * right.size > maxCnfNodeCount) {
-            Seq.empty
+            return Seq.empty
           } else {
             for {x <- left; y <- right} yield Or(x, y)
           }
         case other => other :: Nil
       }
-      if (cnf.isEmpty) {
-        return Seq.empty
-      }
-      if (resultStack.length != 1) {
-        logWarning("The length of CNF conversion result stack is supposed to be 1. There might " +
-          "be something wrong with CNF conversion.")
-        return Seq.empty
-      }
       resultStack.push(cnf)
     }
+    if (resultStack.length != 1) {
+      logWarning("The length of CNF conversion result stack is supposed to be 1. There might " +
+        "be something wrong with CNF conversion.")
+      return Seq.empty
+    }
     resultStack.top
   }
 
-  private def aggregateExpressionsOfSameQualifiers(
+  private def groupExpressionsByQualifier(
       expressions: Seq[Expression]): Seq[Expression] = {
     expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
   }

From 377f9d85809329974f5f349c155717f9cbd40da3 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Wed, 10 Jun 2020 01:31:00 -0700
Subject: [PATCH 15/18] address comments

---
 .../catalyst/optimizer/PushCNFPredicateThroughJoin.scala  | 3 +--
 .../expressions/ConjunctiveNormalFormPredicateSuite.scala | 8 ++++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
index 1ef18703397b0..f406b7d77ab63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
@@ -21,7 +21,6 @@
 import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Try converting the join condition to conjunctive normal form (CNF) so that more predicates
  * can be pushed down.
@@ -33,7 +32,7 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
       val predicates = conjunctiveNormalForm(joinCondition)
-      if (predicates.isEmpty || predicates.size > SQLConf.get.maxCnfNodeCount) {
+      if (predicates.isEmpty) {
         j
       } else {
         val pushDownCandidates = predicates.filter(_.deterministic)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
index a089273159afd..b449ed5cc0d07 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
@@ -30,10 +30,10 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe
   private val d = AttributeReference("D", BooleanType)(exprId = ExprId(4)).withQualifier(Seq("td"))
   private val e = AttributeReference("E", BooleanType)(exprId = ExprId(5)).withQualifier(Seq("te"))
   private val f = AttributeReference("F", BooleanType)(exprId = ExprId(6)).withQualifier(Seq("tf"))
-  private val g = AttributeReference("C", BooleanType)(exprId = ExprId(7)).withQualifier(Seq("tg"))
-  private val h = AttributeReference("D", BooleanType)(exprId = ExprId(8)).withQualifier(Seq("th"))
-  private val i = AttributeReference("E", BooleanType)(exprId = ExprId(9)).withQualifier(Seq("ti"))
-  private val j = AttributeReference("F", BooleanType)(exprId = ExprId(10)).withQualifier(Seq("tj"))
+  private val g = AttributeReference("G", BooleanType)(exprId = ExprId(7)).withQualifier(Seq("tg"))
+  private val h = AttributeReference("H", BooleanType)(exprId = ExprId(8)).withQualifier(Seq("th"))
+  private val i = AttributeReference("I", BooleanType)(exprId = ExprId(9)).withQualifier(Seq("ti"))
+  private val j = AttributeReference("J", BooleanType)(exprId = ExprId(10)).withQualifier(Seq("tj"))
   private val a1 =

From be79ab7e2833b8c554cdfc3496a44c1eba51de9c Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Wed, 10 Jun 2020 02:08:01 -0700
Subject: [PATCH 16/18] update test case

---
 .../optimizer/FilterPushdownSuite.scala       | 37 ++++++++-----------
 1 file changed, 15 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index aad8fdf395fd5..c1812d9e76380 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -58,6 +58,15 @@ class FilterPushdownSuite extends PlanTest {
 
   val testRelation1 = LocalRelation(attrD)
 
+  val simpleDisjunctivePredicate =
+    ("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)
+  val expectedCNFPredicatePushDownResult = {
+    val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x)
+    val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
+    left.join(right, condition = Some("x.b".attr === "y.b".attr
+      && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))).analyze
+  }
+
   // This test already passes.
test("eliminate subqueries") { val originalQuery = @@ -1244,19 +1253,11 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y) - .where(("x.b".attr === "y.b".attr) - && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))) + .where(("x.b".attr === "y.b".attr) && (simpleDisjuncitvePredicate)) } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x) - val right = testRelation.where('a > 13 || 'a > 11).subquery('y) - val correctAnswer = - left.join(right, condition = Some("x.b".attr === "y.b".attr - && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) - .analyze - - comparePlans(optimized, correctAnswer) + comparePlans(optimized, expectedCNFPredicatePushDownResult) } test("inner join: rewrite join predicates to conjunctive normal form") { @@ -1264,19 +1265,11 @@ class FilterPushdownSuite extends PlanTest { val y = testRelation.subquery('y) val originalQuery = { - x.join(y, condition = Some(("x.b".attr === "y.b".attr) - && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + x.join(y, condition = Some(("x.b".attr === "y.b".attr) && (simpleDisjuncitvePredicate))) } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('a > 3 || 'a > 1).subquery('x) - val right = testRelation.where('a > 13 || 'a > 11).subquery('y) - val correctAnswer = - left.join(right, condition = Some("x.b".attr === "y.b".attr - && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) - .analyze - - comparePlans(optimized, correctAnswer) + comparePlans(optimized, expectedCNFPredicatePushDownResult) } test("inner join: rewrite complex join predicates to conjunctive normal form") { @@ -1326,7 +1319,7 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr) - && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + && simpleDisjuncitvePredicate)) } val optimized = Optimize.execute(originalQuery.analyze) @@ -1346,7 +1339,7 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr) - && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + && simpleDisjuncitvePredicate)) } val optimized = Optimize.execute(originalQuery.analyze) From af018be5aa50d5fd19f6acec7754dd7ca60f1852 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 10 Jun 2020 15:10:55 -0700 Subject: [PATCH 17/18] address comments --- .../sql/catalyst/expressions/predicates.scala | 25 ++++++++------- .../PushCNFPredicateThroughJoin.scala | 3 +- .../optimizer/FilterPushdownSuite.scala | 31 +++++++------------ 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index bd585a3b1c97d..c9b57367e0f44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -204,13 +204,12 @@ trait PredicateHelper extends Logging { /** * Convert an expression into conjunctive normal form. 
* Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form - * CNF can explode exponentially in the size of the input expression when converting Or clauses. - * Use a configuration MAX_CNF_NODE_COUNT to prevent such cases. + * CNF can explode exponentially in the size of the input expression when converting [[Or]] + * clauses. Use the configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases. * - * @param condition to be conversed into CNF. - * @return If the number of expressions exceeds threshold on converting Or, return Seq.empty. - * If the conversion repeatedly expands nondeterministic expressions, return Seq.empty. - * Otherwise, return the converted result as sequence of disjunctive expressions. + * @param condition to be converted into CNF. + * @return the CNF result as a sequence of disjunctive expressions. If the number of expressions + * exceeds the threshold on converting `Or`, `Seq.empty` is returned. */ def conjunctiveNormalForm(condition: Expression): Seq[Expression] = { val postOrderNodes = postOrderTraversal(condition) @@ -220,20 +219,23 @@ trait PredicateHelper extends Logging { while (postOrderNodes.nonEmpty) { val cnf = postOrderNodes.pop() match { case _: And => - val right: Seq[Expression] = resultStack.pop() - val left: Seq[Expression] = resultStack.pop() + val right = resultStack.pop() + val left = resultStack.pop() left ++ right case _: Or => // For each side, there is no need to expand predicates of the same references. - // So here we can aggregate predicates of the same references as one single predicate, + // So here we can aggregate predicates of the same qualifier into a single predicate, // for reducing the size of pushed down predicates and corresponding codegen. val right = groupExpressionsByQualifier(resultStack.pop()) val left = groupExpressionsByQualifier(resultStack.pop()) // Stop the loop whenever the result exceeds the `maxCnfNodeCount` if (left.size * right.size > maxCnfNodeCount) { + logInfo(s"The CNF conversion result size exceeds the threshold $maxCnfNodeCount. " + + "The conversion is skipped and Seq.empty is returned. 
To avoid this, you can " + + s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.") return Seq.empty } else { - for {x <- left; y <- right} yield Or(x, y) + for { x <- left; y <- right } yield Or(x, y) } case other => other :: Nil } @@ -247,8 +249,7 @@ trait PredicateHelper extends Logging { resultStack.top } - private def groupExpressionsByQualifier( - expressions: Seq[Expression]): Seq[Expression] = { + private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = { expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala index f406b7d77ab63..c7848b9b9ea4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.rules.Rule */ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case j @ Join(left, right, joinType, Some(joinCondition), hint) => + case j @ Join(left, right, joinType, Some(joinCondition), hint) if joinType != FullOuter => val predicates = conjunctiveNormalForm(joinCondition) if (predicates.isEmpty) { j @@ -53,7 +53,6 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe Join(newLeft, right, RightOuter, Some(joinCondition), hint) case LeftOuter | LeftAnti | ExistenceJoin(_) => Join(left, newRight, joinType, Some(joinCondition), hint) - case FullOuter => j case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") case UsingJoin(_, _) => sys.error("Untransformed Using join node") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index c1812d9e76380..bb8f5f90f8508 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -58,7 +58,7 @@ class FilterPushdownSuite extends PlanTest { val testRelation1 = LocalRelation(attrD) - val simpleDisjuncitvePredicate = + val simpleDisjunctivePredicate = ("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11) val expectedCNFPredicatePushDownResult = { val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x) @@ -1251,10 +1251,7 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val originalQuery = { - x.join(y) - .where(("x.b".attr === "y.b".attr) && (simpleDisjuncitvePredicate)) - } + val originalQuery = x.join(y).where(("x.b".attr === "y.b".attr) && (simpleDisjunctivePredicate)) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, expectedCNFPredicatePushDownResult) @@ -1264,9 +1261,8 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val originalQuery = { - x.join(y, condition = Some(("x.b".attr === "y.b".attr) && (simpleDisjuncitvePredicate))) - } + val originalQuery = + x.join(y, condition = Some(("x.b".attr === "y.b".attr) && 
(simpleDisjunctivePredicate))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, expectedCNFPredicatePushDownResult) @@ -1296,11 +1292,10 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val originalQuery = { + val originalQuery = x.join(y, condition = Some(("x.b".attr === "y.b".attr) && Not(("x.a".attr > 3) && ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11)))) - } val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a <= 3 || 'a >= 2).subquery('x) @@ -1317,10 +1312,9 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val originalQuery = { + val originalQuery = x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr) - && simpleDisjuncitvePredicate)) - } + && simpleDisjunctivePredicate)) val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.subquery('x) @@ -1337,10 +1331,9 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val originalQuery = { + val originalQuery = x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr) - && simpleDisjuncitvePredicate)) - } + && simpleDisjunctivePredicate)) val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a > 3 || 'a > 1).subquery('x) @@ -1357,10 +1350,9 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val originalQuery = { + val originalQuery = x.join(y, condition = Some(("x.b".attr === "y.b".attr) && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) || (("y.a".attr > 2) && ("y.c".attr < 1))))) - } val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.subquery('x) @@ -1376,11 +1368,10 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val originalQuery = { + val originalQuery = x.join(y, condition = Some(("x.b".attr === "y.b".attr) && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) || (("y.a".attr > 2) && ("y.c".attr < 1))))) - } Seq(0, 10).foreach { count => withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> count.toString) { From b42ce1dde86069a74ebbda44a1729cde39c2672d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 10 Jun 2020 17:52:28 -0700 Subject: [PATCH 18/18] fix build --- .../sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala index c7848b9b9ea4a..f406b7d77ab63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.rules.Rule */ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case j @ Join(left, right, joinType, Some(joinCondition), hint) if joinType != FullOuter => + case j @ Join(left, right, joinType, Some(joinCondition), hint) => val predicates = 
conjunctiveNormalForm(joinCondition) if (predicates.isEmpty) { j @@ -53,6 +53,7 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe Join(newLeft, right, RightOuter, Some(joinCondition), hint) case LeftOuter | LeftAnti | ExistenceJoin(_) => Join(left, newRight, joinType, Some(joinCondition), hint) + case FullOuter => j case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") case UsingJoin(_, _) => sys.error("Untransformed Using join node") }
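
As a worked illustration of the conversion these patches converge on, here is a minimal standalone sketch of the Or-distribution step, applied to the test suite's simpleDisjunctivePredicate. It is a toy model under stated assumptions, not the Catalyst implementation: Expr, Leaf, and toCNF are invented names, while the real conjunctiveNormalForm works on Catalyst Expression trees, handles Not via De Morgan's laws during post-order traversal, and groups each Or operand's conjuncts by qualifier before taking the cross product.

// Toy expression tree standing in for Catalyst's And/Or predicates.
sealed trait Expr
case class Leaf(name: String) extends Expr
case class And(l: Expr, r: Expr) extends Expr
case class Or(l: Expr, r: Expr) extends Expr

// Returns the conjuncts of the CNF form, or Seq.empty once an Or cross
// product would exceed maxNodeCount -- the same guard the patches
// implement with SQLConf.get.maxCnfNodeCount.
def toCNF(e: Expr, maxNodeCount: Int = 128): Seq[Expr] = e match {
  case And(l, r) =>
    val (left, right) = (toCNF(l, maxNodeCount), toCNF(r, maxNodeCount))
    if (left.isEmpty || right.isEmpty) Seq.empty else left ++ right
  case Or(l, r) =>
    val (left, right) = (toCNF(l, maxNodeCount), toCNF(r, maxNodeCount))
    if (left.isEmpty || right.isEmpty || left.size * right.size > maxNodeCount) Seq.empty
    else for { x <- left; y <- right } yield Or(x, y)
  case leaf => Seq(leaf)
}

// simpleDisjunctivePredicate: (x.a > 3 && y.a > 13) || (x.a > 1 && y.a > 11)
val a = Leaf("x.a > 3")
val b = Leaf("y.a > 13")
val c = Leaf("x.a > 1")
val d = Leaf("y.a > 11")
val conjuncts = toCNF(Or(And(a, b), And(c, d)))
// conjuncts == Seq(Or(a, c), Or(a, d), Or(b, c), Or(b, d))

Of the four conjuncts, Or(a, c) references only relation x and Or(b, d) references only y, which is why expectedCNFPredicatePushDownResult filters the left side with 'a > 3 || 'a > 1 and the right side with 'a > 13 || 'a > 11 while the join keeps the original condition. The mixed conjuncts Or(a, d) and Or(b, c) reference both sides, fail the references.subsetOf(...) checks, and are not pushed.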