diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2c4f41f98ac20..c9b57367e0f44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.immutable.TreeSet
+import scala.collection.mutable
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
@@ -95,7 +97,7 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr
   }
 }
 
-trait PredicateHelper {
+trait PredicateHelper extends Logging {
   protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
     condition match {
       case And(cond1, cond2) =>
@@ -198,6 +200,98 @@ trait PredicateHelper {
     case e: Unevaluable => false
     case e => e.children.forall(canEvaluateWithinJoin)
   }
+
+  /**
+   * Converts an expression into conjunctive normal form (CNF).
+   * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
+   * CNF can explode exponentially in the size of the input expression when converting [[Or]]
+   * clauses. The configuration [[SQLConf.MAX_CNF_NODE_COUNT]] guards against such cases.
+   *
+   * @param condition the expression to be converted into CNF.
+   * @return the CNF result as a sequence of disjunctive expressions. If the number of expressions
+   *         exceeds the threshold while converting `Or` clauses, `Seq.empty` is returned.
+   */
+  def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
+    val postOrderNodes = postOrderTraversal(condition)
+    val resultStack = new mutable.Stack[Seq[Expression]]
+    val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
+    // Bottom-up approach to build the CNF of the sub-expressions
+    while (postOrderNodes.nonEmpty) {
+      val cnf = postOrderNodes.pop() match {
+        case _: And =>
+          val right = resultStack.pop()
+          val left = resultStack.pop()
+          left ++ right
+        case _: Or =>
+          // For each side, there is no need to expand predicates sharing the same references.
+          // So we can aggregate the predicates of the same qualifier into one single predicate,
+          // to reduce the size of the pushed-down predicates and the corresponding codegen.
+          val right = groupExpressionsByQualifier(resultStack.pop())
+          val left = groupExpressionsByQualifier(resultStack.pop())
+          // Stop the conversion as soon as the result would exceed `maxCnfNodeCount`
+          if (left.size * right.size > maxCnfNodeCount) {
+            logInfo(s"The CNF result size exceeds the threshold $maxCnfNodeCount, so the " +
+              "conversion is skipped and Seq.empty is returned. To avoid this, raise the limit " +
+              s"${SQLConf.MAX_CNF_NODE_COUNT.key}.")
+            return Seq.empty
+          } else {
+            for { x <- left; y <- right } yield Or(x, y)
+          }
+        case other => other :: Nil
+      }
+      resultStack.push(cnf)
+    }
+    if (resultStack.length != 1) {
+      logWarning("The length of the CNF conversion result stack is supposed to be 1. There " +
+        "might be something wrong with the CNF conversion.")
+      return Seq.empty
+    }
+    resultStack.top
+  }
+
+  private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
+    expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
+  }
+
+  /**
+   * Iterative post-order traversal over a binary tree built by And/Or clauses, using two stacks.
+   * For example, for the condition `(a And b) Or c`, the post-order traversal is
+   * (`a`, `b`, `And`, `c`, `Or`).
+   * The complete algorithm is as follows; after step 2, the post-order traversal is in
+   * the second stack.
+   *   1. Push the root to the first stack.
+   *   2. Loop while the first stack is not empty:
+   *     2.1 Pop a node from the first stack and push it to the second stack.
+   *     2.2 Push the children of the popped node to the first stack.
+   *
+   * @param condition the expression to be traversed as a binary tree
+   * @return the sub-expressions in post-order traversal as a stack.
+   *         The first element of the result stack is the leftmost node.
+   */
+  private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
+    val stack = new mutable.Stack[Expression]
+    val result = new mutable.Stack[Expression]
+    stack.push(condition)
+    while (stack.nonEmpty) {
+      val node = stack.pop()
+      node match {
+        case Not(a And b) => stack.push(Or(Not(a), Not(b)))
+        case Not(a Or b) => stack.push(And(Not(a), Not(b)))
+        case Not(Not(a)) => stack.push(a)
+        case a And b =>
+          result.push(node)
+          stack.push(a)
+          stack.push(b)
+        case a Or b =>
+          result.push(node)
+          stack.push(a)
+          stack.push(b)
+        case _ =>
+          result.push(node)
+      }
+    }
+    result
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f1a307b1c2cc1..a1a7213664ac8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -51,7 +51,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val blacklistedOnceBatches: Set[String] =
     Set(
      "PartitionPruning",
-      "Extract Python UDFs")
+      "Extract Python UDFs",
+      "Push CNF predicate through join")
 
   protected def fixedPoint =
     FixedPoint(
@@ -118,7 +119,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
     Batch("Infer Filters", Once,
       InferFiltersFromConstraints) ::
     Batch("Operator Optimization after Inferring Filters", fixedPoint,
-      rulesWithoutInferFiltersFromConstraints: _*) :: Nil
+      rulesWithoutInferFiltersFromConstraints: _*) ::
+    // Set the strategy to Once: the rule does not change the join condition, so there is no
+    // need to push the same filters down repeatedly.
+    Batch("Push CNF predicate through join", Once,
+      PushCNFPredicateThroughJoin) :: Nil
   }
 
   val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
new file mode 100644
index 0000000000000..f406b7d77ab63
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Tries converting the join condition to conjunctive normal form so that more predicates can be
+ * pushed down. To avoid expanding the join condition, it is kept in its original form even when
+ * predicate pushdown happens.
+ */
+object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
+      val predicates = conjunctiveNormalForm(joinCondition)
+      if (predicates.isEmpty) {
+        j
+      } else {
+        val pushDownCandidates = predicates.filter(_.deterministic)
+        lazy val leftFilterConditions =
+          pushDownCandidates.filter(_.references.subsetOf(left.outputSet))
+        lazy val rightFilterConditions =
+          pushDownCandidates.filter(_.references.subsetOf(right.outputSet))
+
+        lazy val newLeft =
+          leftFilterConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+        lazy val newRight =
+          rightFilterConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+
+        joinType match {
+          case _: InnerLike | LeftSemi =>
+            Join(newLeft, newRight, joinType, Some(joinCondition), hint)
+          case RightOuter =>
+            Join(newLeft, right, RightOuter, Some(joinCondition), hint)
+          case LeftOuter | LeftAnti | ExistenceJoin(_) =>
+            Join(left, newRight, joinType, Some(joinCondition), hint)
+          case FullOuter => j
+          case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
+          case UsingJoin(_, _) => sys.error("Untransformed Using join node")
+        }
+      }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 189740e313207..33f40b47d072b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -545,6 +545,19 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val MAX_CNF_NODE_COUNT =
+    buildConf("spark.sql.optimizer.maxCNFNodeCount")
+      .internal()
+      .doc("Specifies the maximum allowable number of conjuncts in the result of CNF " +
+        "conversion. If the conversion exceeds the threshold, an empty sequence is returned. " +
+        "For example, CNF conversion of (a && b) || (c && d) generates " +
+        "four conjuncts (a || c) && (a || d) && (b || c) && (b || d).")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ >= 0,
+        "The maximum number of conjuncts allowed in CNF conversion must be non-negative.")
+      .createWithDefault(128)
+
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -2874,6 +2887,8 @@ class SQLConf extends Serializable with Logging {
 
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
+  def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT)
+
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
   def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
new file mode 100644
index 0000000000000..b449ed5cc0d07
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.BooleanType
+
+class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHelper with PlanTest {
+  private val a = AttributeReference("A", BooleanType)(exprId = ExprId(1)).withQualifier(Seq("ta"))
+  private val b = AttributeReference("B", BooleanType)(exprId = ExprId(2)).withQualifier(Seq("tb"))
+  private val c = AttributeReference("C", BooleanType)(exprId = ExprId(3)).withQualifier(Seq("tc"))
+  private val d = AttributeReference("D", BooleanType)(exprId = ExprId(4)).withQualifier(Seq("td"))
+  private val e = AttributeReference("E", BooleanType)(exprId = ExprId(5)).withQualifier(Seq("te"))
+  private val f = AttributeReference("F", BooleanType)(exprId = ExprId(6)).withQualifier(Seq("tf"))
+  private val g = AttributeReference("G", BooleanType)(exprId = ExprId(7)).withQualifier(Seq("tg"))
+  private val h = AttributeReference("H", BooleanType)(exprId = ExprId(8)).withQualifier(Seq("th"))
+  private val i = AttributeReference("I", BooleanType)(exprId = ExprId(9)).withQualifier(Seq("ti"))
+  private val j = AttributeReference("J", BooleanType)(exprId = ExprId(10)).withQualifier(Seq("tj"))
+  private val a1 =
+    AttributeReference("a1", BooleanType)(exprId = ExprId(11)).withQualifier(Seq("ta"))
+  private val a2 =
+    AttributeReference("a2", BooleanType)(exprId = ExprId(12)).withQualifier(Seq("ta"))
+  private val b1 =
+    AttributeReference("b1", BooleanType)(exprId = ExprId(13)).withQualifier(Seq("tb"))
+
+  // Checks CNF conversion against the expected expression, assuming the input converts to a
+  // non-empty result.
+  private def checkCondition(input: Expression, expected: Expression): Unit = {
+    val cnf = conjunctiveNormalForm(input)
+    assert(cnf.nonEmpty)
+    val result = cnf.reduceLeft(And)
+    assert(result.semanticEquals(expected))
+  }
+
+  test("Keep non-predicated expressions") {
+    checkCondition(a, a)
+    checkCondition(Literal(1), Literal(1))
+  }
+
+  test("Conversion of Not") {
+    checkCondition(!a, !a)
+    checkCondition(!(!a), a)
+    checkCondition(!(!(a && b)), a && b)
+    checkCondition(!(!(a || b)), a || b)
+    checkCondition(!(a || b), !a && !b)
+    checkCondition(!(a && b), !a || !b)
+  }
+
+  test("Conversion of And") {
+    checkCondition(a && b, a && b)
+    checkCondition(a && b && c, a && b && c)
+    checkCondition(a && (b || c), a && (b || c))
+    checkCondition((a || b) && c, (a || b) && c)
+    checkCondition(a && b && c && d, a && b && c && d)
+  }
+
+  test("Conversion of Or") {
+    checkCondition(a || b, a || b)
+    checkCondition(a || b || c, a || b || c)
+    checkCondition(a || b || c || d, a || b || c || d)
+    checkCondition((a && b) || c, (a || c) && (b || c))
+    checkCondition((a && b) || (c && d), (a || c) && (a || d) && (b || c) && (b || d))
+  }
+
+  test("More complex cases") {
+    checkCondition(a && !(b || c), a && !b && !c)
+    checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d))
+    checkCondition(a || b || c && d, (a || b || c) && (a || b || d))
+    checkCondition(a || (b && c || d), (a || b || d) && (a || c || d))
+    checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e))
+    checkCondition(((a && b) || c) || (d || e), (a || c || d || e) && (b || c || d || e))
+
+    checkCondition(
+      (a && b && c) || (d && e && f),
+      (a || d) && (a || e) && (a || f) && (b || d) && (b || e) && (b || f) &&
+        (c || d) && (c || e) && (c || f)
+    )
+  }
+
+  test("Aggregate predicates of the same qualifiers to avoid expanding") {
+    checkCondition((a && b && a1) || c, ((a && a1) || c) && (b || c))
+    checkCondition((a && a1 && b) || c, ((a && a1) || c) && (b || c))
+    checkCondition((b && d && a && a1) || c, ((a && a1) || c) && (b || c) && (d || c))
+    checkCondition((b && a2 && d && a && a1) || c, ((a2 && a && a1) || c) && (b || c) && (d || c))
+    checkCondition((b && d && a && a1 && b1) || c,
+      ((a && a1) || c) && ((b && b1) || c) && (d || c))
+    checkCondition((a && a1) || (b && b1), (a && a1) || (b && b1))
+    checkCondition((a && a1 && c) || (b && b1), ((a && a1) || (b && b1)) && (c || (b && b1)))
+  }
+
+  test("Return Seq.empty when exceeding MAX_CNF_NODE_COUNT") {
+    // The following expression contains 36 conjunctive sub-expressions in CNF
+    val input = (a && b && c) || (d && e && f) || (g && h && i && j)
+    // The following expression contains 9 conjunctive sub-expressions in CNF
+    val input2 = (a && b && c) || (d && e && f)
+    Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
+      withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
+        if (maxCount < 36) {
+          assert(conjunctiveNormalForm(input).isEmpty)
+        } else {
+          assert(conjunctiveNormalForm(input).nonEmpty)
+        }
+        if (maxCount < 9) {
+          assert(conjunctiveNormalForm(input2).isEmpty)
+        } else {
+          assert(conjunctiveNormalForm(input2).nonEmpty)
+        }
+      }
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 70e29dca46e9e..bb8f5f90f8508 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -25,12 +25,17 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, IntegerType}
 import org.apache.spark.unsafe.types.CalendarInterval
 
 class FilterPushdownSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
+
+    override protected val blacklistedOnceBatches: Set[String] =
+      Set("Push CNF predicate through join")
+
     val batches =
       Batch("Subqueries", Once,
         EliminateSubqueryAliases) ::
@@ -39,7 +44,9 @@ class FilterPushdownSuite extends PlanTest {
         PushPredicateThroughNonJoin,
         BooleanSimplification,
         PushPredicateThroughJoin,
-        CollapseProject) :: Nil
+        CollapseProject) ::
+      Batch("Push CNF predicate through join", Once,
+        PushCNFPredicateThroughJoin) :: Nil
   }
 
   val attrA = 'a.int
@@ -51,6 +58,15 @@ class FilterPushdownSuite extends PlanTest {
 
   val testRelation1 = LocalRelation(attrD)
 
+  val simpleDisjunctivePredicate =
+    ("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)
+  val expectedCNFPredicatePushDownResult = {
+    val left = testRelation.where('a > 3 || 'a > 1).subquery('x)
+    val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
+    left.join(right, condition = Some("x.b".attr === "y.b".attr
+      && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))).analyze
+  }
+
   // This test already passes.
   test("eliminate subqueries") {
     val originalQuery =
@@ -1230,4 +1246,148 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(Optimize.execute(query.analyze), expected)
   }
+
+  test("inner join: rewrite filter predicates to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = x.join(y).where(("x.b".attr === "y.b".attr) && simpleDisjunctivePredicate)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, expectedCNFPredicatePushDownResult)
+  }
+
+  test("inner join: rewrite join predicates to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery =
+      x.join(y, condition = Some(("x.b".attr === "y.b".attr) && simpleDisjunctivePredicate))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, expectedCNFPredicatePushDownResult)
+  }
+
+  test("inner join: rewrite complex join predicates to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val joinCondition = (("x.b".attr === "y.b".attr)
+      && ((("x.a".attr === 5) && ("y.a".attr >= 2) && ("y.a".attr <= 3))
+        || (("x.a".attr === 2) && ("y.a".attr >= 1) && ("y.a".attr <= 14))
+        || (("x.a".attr === 1) && ("y.a".attr >= 9) && ("y.a".attr <= 27))))
+
+    val originalQuery = x.join(y, condition = Some(joinCondition))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.where(
+      'a === 5 || 'a === 2 || 'a === 1).subquery('x)
+    val right = testRelation.where(
+      ('a >= 2 && 'a <= 3) || ('a >= 1 && 'a <= 14) || ('a >= 9 && 'a <= 27)).subquery('y)
+    val correctAnswer = left.join(right, condition = Some(joinCondition)).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("inner join: rewrite join predicates (with NOT predicate) to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery =
+      x.join(y, condition = Some(("x.b".attr === "y.b".attr)
+        && Not(("x.a".attr > 3)
+          && ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.where('a <= 3 || 'a >= 2).subquery('x)
+    val right = testRelation.subquery('y)
+    val correctAnswer =
+      left.join(right, condition = Some("x.b".attr === "y.b".attr
+        && (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13)))
+        && (("x.a".attr <= 1) || ("y.a".attr <= 11))))
+        .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("left join: rewrite join predicates to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery =
+      x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr)
+        && simpleDisjunctivePredicate))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.subquery('x)
+    val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
+    val correctAnswer =
+      left.join(right, joinType = LeftOuter, condition = Some("x.b".attr === "y.b".attr
+        && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("right join: rewrite join predicates to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery =
+      x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr)
+        && simpleDisjunctivePredicate))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.where('a > 3 || 'a > 1).subquery('x)
+    val right = testRelation.subquery('y)
+    val correctAnswer =
+      left.join(right, joinType = RightOuter, condition = Some("x.b".attr === "y.b".attr
+        && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("inner join: rewrite to conjunctive normal form to avoid generating too many predicates") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery =
+      x.join(y, condition = Some(("x.b".attr === "y.b".attr) && ((("x.a".attr > 3) &&
+        ("x.a".attr < 13) && ("y.c".attr <= 5)) || (("y.a".attr > 2) && ("y.c".attr < 1)))))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.subquery('x)
+    val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)
+    val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr &&
+      ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) ||
+        (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_CNF_NODE_COUNT.key}=0") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery =
+      x.join(y, condition = Some(("x.b".attr === "y.b".attr)
+        && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
+          || (("y.a".attr > 2) && ("y.c".attr < 1)))))
+
+    Seq(0, 10).foreach { count =>
+      withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> count.toString) {
+        val optimized = Optimize.execute(originalQuery.analyze)
+        val (left, right) = if (count == 0) {
+          (testRelation.subquery('x), testRelation.subquery('y))
+        } else {
+          (testRelation.subquery('x),
+            testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y))
+        }
+        val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
+          && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
+            || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze
+
+        comparePlans(optimized, correctAnswer)
+      }
+    }
+  }
 }
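
For readers who want to experiment with the conversion outside Spark, here is a minimal, self-contained sketch of what `conjunctiveNormalForm` above computes. This is not the patch's code: the `Expr` ADT, `toCNF`, and the default limit of 128 are illustrative stand-ins, the recursion replaces the patch's two-stack iteration, and the qualifier-grouping step is omitted.

```scala
// Illustrative sketch only; `Expr`, `toCNF`, and `maxCnfNodeCount` are not Spark APIs.
sealed trait Expr
case class Leaf(name: String) extends Expr
case class Not(e: Expr) extends Expr
case class And(l: Expr, r: Expr) extends Expr
case class Or(l: Expr, r: Expr) extends Expr

object CnfSketch {
  // Returns the conjuncts of the CNF form, or Seq.empty once an Or expansion
  // would exceed `maxCnfNodeCount`, mirroring the early exit in the patch.
  def toCNF(e: Expr, maxCnfNodeCount: Int = 128): Seq[Expr] = e match {
    case And(l, r) =>
      val (cl, cr) = (toCNF(l, maxCnfNodeCount), toCNF(r, maxCnfNodeCount))
      if (cl.isEmpty || cr.isEmpty) Seq.empty else cl ++ cr
    case Or(l, r) =>
      val (cl, cr) = (toCNF(l, maxCnfNodeCount), toCNF(r, maxCnfNodeCount))
      if (cl.isEmpty || cr.isEmpty || cl.size * cr.size > maxCnfNodeCount) Seq.empty
      else for { x <- cl; y <- cr } yield Or(x, y)
    // Push Not inward with De Morgan's laws, as postOrderTraversal does.
    case Not(And(l, r)) => toCNF(Or(Not(l), Not(r)), maxCnfNodeCount)
    case Not(Or(l, r)) => toCNF(And(Not(l), Not(r)), maxCnfNodeCount)
    case Not(Not(inner)) => toCNF(inner, maxCnfNodeCount)
    case other => Seq(other)
  }

  def main(args: Array[String]): Unit = {
    val (a, b, c, d) = (Leaf("a"), Leaf("b"), Leaf("c"), Leaf("d"))
    // (a && b) || (c && d) yields the four conjuncts from the config doc:
    // (a || c), (a || d), (b || c), (b || d)
    println(toCNF(Or(And(a, b), And(c, d))))
    // With the limit set to 3, the same input needs 2 * 2 = 4 conjuncts,
    // so the conversion bails out and returns Seq.empty.
    println(toCNF(Or(And(a, b), And(c, d)), maxCnfNodeCount = 3))
  }
}
```

The multiplicative guard explains why the patch checks `left.size * right.size` before expanding an `Or`: that product is exactly the number of conjuncts the cross product would produce.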
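Similarly, a hedged sketch of the two-stack post-order traversal described in the `postOrderTraversal` doc comment, reusing the illustrative `Expr` ADT from the previous sketch (again assumed names, not the patch's code):

```scala
import scala.collection.mutable

object PostOrderSketch {
  // Two-stack iterative post-order traversal: pop from the first stack,
  // push to the second, then push the children to the first.
  def postOrder(root: Expr): mutable.Stack[Expr] = {
    val stack = mutable.Stack[Expr](root)
    val result = mutable.Stack[Expr]()
    while (stack.nonEmpty) {
      val node = stack.pop()
      node match {
        case And(l, r) =>
          result.push(node)
          stack.push(l)
          stack.push(r)
        case Or(l, r) =>
          result.push(node)
          stack.push(l)
          stack.push(r)
        case _ =>
          result.push(node)
      }
    }
    result
  }

  def main(args: Array[String]): Unit = {
    val (a, b, c) = (Leaf("a"), Leaf("b"), Leaf("c"))
    val traversal = postOrder(Or(And(a, b), c))
    // Prints the nodes in post order: Leaf(a), Leaf(b), And(..), Leaf(c), Or(..),
    // matching the (`a`, `b`, `And`, `c`, `Or`) example in the doc comment above.
    while (traversal.nonEmpty) println(traversal.pop())
  }
}
```

Popping the result stack yields leaves before their parent clauses, which is why `conjunctiveNormalForm` can consume it bottom-up with a single result stack.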