From 21fb7c5844f43ec0b9190ac6e823aba2854ba2bd Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 18 May 2020 13:41:13 +0800 Subject: [PATCH 1/4] Rewriting join condition to conjunctive normal form expression --- .../sql/catalyst/optimizer/Optimizer.scala | 78 ++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 14 +++ .../optimizer/FilterPushdownSuite.scala | 106 +++++++++++++++++- 3 files changed, 196 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e59e3b999aa7f..f510d68b3b7bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,7 +118,9 @@ abstract class Optimizer(catalogManager: CatalogManager) Batch("Infer Filters", Once, InferFiltersFromConstraints) :: Batch("Operator Optimization after Inferring Filters", fixedPoint, - rulesWithoutInferFiltersFromConstraints: _*) :: Nil + rulesWithoutInferFiltersFromConstraints: _*) :: + Batch("Push predicate through join by conjunctive normal form", Once, + PushPredicateThroughJoinByCNF) :: Nil } val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) :: @@ -1372,6 +1374,80 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Rewriting join condition to conjunctive normal form expression so that we can push + * more predicates. + */ +object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Rewrite pattern: + * 1. (a && b) || c --> (a || c) && (b || c) + * 2. a || (b && c) --> (a || b) && (a || c) + * 3. 
!(a || b) --> !a && !b + */ + private def rewriteToCNF(condition: Expression, depth: Int = 0): Expression = { + if (depth < SQLConf.get.maxRewritingCNFDepth) { + val nextDepth = depth + 1 + condition match { + case Or(And(a, b), c) => + And(rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth), + rewriteToCNF(Or(rewriteToCNF(b, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth)) + case Or(a, And(b, c)) => + And(rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth)), nextDepth), + rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth)) + case Not(Or(a, b)) => + And(rewriteToCNF(Not(rewriteToCNF(a, nextDepth)), nextDepth), + rewriteToCNF(Not(rewriteToCNF(b, nextDepth)), nextDepth)) + case And(a, b) => + And(rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth)) + case other => other + } + } else { + condition + } + } + + private def maybeWithFilter(joinCondition: Seq[Expression], plan: LogicalPlan) = { + (joinCondition.reduceLeftOption(And).reduceLeftOption(And), plan) match { + case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) => + plan + case (Some(condition), _) => + Filter(condition, plan) + case _ => + plan + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally + + val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { + case j @ Join(left, right, joinType, Some(joinCondition), hint) => + + val pushDownCandidates = splitConjunctivePredicates(rewriteToCNF(joinCondition)) + .filter(_.deterministic) + val (leftEvaluateCondition, rest) = + pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) + val (rightEvaluateCondition, _) = + rest.partition(expr => expr.references.subsetOf(right.outputSet)) + + val newLeft = maybeWithFilter(leftEvaluateCondition, left) + val newRight = maybeWithFilter(rightEvaluateCondition, right) + + joinType match { + case _: InnerLike | LeftSemi => + Join(newLeft, newRight, joinType, Some(joinCondition), hint) + case RightOuter => + Join(newLeft, right, RightOuter, Some(joinCondition), hint) + case LeftOuter | LeftAnti | ExistenceJoin(_) => + Join(left, newRight, joinType, Some(joinCondition), hint) + case FullOuter => j + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") + case UsingJoin(_, _) => sys.error("Untransformed Using join node") + } + } +} + /** * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c739fa516f0c1..aed5d23d197b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -544,6 +544,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val MAX_REWRITING_CNF_DEPTH = + buildConf("spark.sql.maxRewritingCNFDepth") + .internal() + .doc("The maximum depth of rewriting a join condition to conjunctive normal form " + + "expression. The deeper, the more predicates may be found, but the optimization time " + + "will increase. The default is 6. 
By setting this value to 0 this feature can be disabled.") + .version("3.1.0") + .intConf + .checkValue(_ >= 0, + "The maximum depth of rewriting to conjunctive normal form must be non-negative.") + .createWithDefault(6) + val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") .internal() .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + @@ -2845,6 +2857,8 @@ class SQLConf extends Serializable with Logging { def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + def maxRewritingCNFDepth: Int = getConf(MAX_REWRITING_CNF_DEPTH) + def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 70e29dca46e9e..d8a05279a5df0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -39,7 +39,9 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughNonJoin, BooleanSimplification, PushPredicateThroughJoin, - CollapseProject) :: Nil + CollapseProject) :: + Batch("PushPredicateThroughJoinByCNF", Once, + PushPredicateThroughJoinByCNF) :: Nil } val attrA = 'a.int @@ -1230,4 +1232,106 @@ class FilterPushdownSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), expected) } + + test("inner join: rewrite filter predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x) + val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join: rewrite join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a > 3 || 'a > 1).subquery('x) + val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join: rewrite join predicates (with NOT predicate) to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && Not(("x.a".attr > 3) + && ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val 
left = testRelation.where('a <= 3 || 'a >= 2).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13))) + && (("x.a".attr <= 1) || ("y.a".attr <= 11)))) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("left join: rewrite join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.subquery('x) + val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val correctAnswer = + left.join(right, joinType = LeftOuter, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("right join: rewrite join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a > 3 || 'a > 1).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + left.join(right, joinType = RightOuter, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } } From b38404ce1430a4ed13a620f8d143bb35447d9890 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 20 May 2020 23:51:48 +0800 Subject: [PATCH 2/4] Avoid generating too many predicates --- .../sql/catalyst/optimizer/Optimizer.scala | 60 +++++++++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../optimizer/FilterPushdownSuite.scala | 20 +++++++ 3 files changed, 75 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f510d68b3b7bf..fb06cc0806ef5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -1380,6 +1382,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { */ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper { + + // 
Used to group same side expressions to avoid generating too many duplicate predicates. + private case class SameSide(exps: Seq[Expression]) extends CodegenFallback { + override def children: Seq[Expression] = exps + override def nullable: Boolean = true + override def dataType: DataType = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + } + /** * Rewrite pattern: * 1. (a && b) || c --> (a || c) && (b || c) @@ -1408,8 +1418,43 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel } } - private def maybeWithFilter(joinCondition: Seq[Expression], plan: LogicalPlan) = { - (joinCondition.reduceLeftOption(And).reduceLeftOption(And), plan) match { + /** + * Split And expression by single side references. For example, + * t1.a > 1 and t1.a < 10 and t2.a < 10 --> + * SameSide(t1.a > 1, t1.a < 10) and SameSide(t2.a < 10) + */ + private def splitAndExp(and: And, outputSet: AttributeSet) = { + val (leftSide, rightSide) = + splitConjunctivePredicates(and).partition(_.references.subsetOf(outputSet)) + Seq(SameSide(leftSide), SameSide(rightSide)).filter(_.exps.nonEmpty).reduceLeft(And) + } + + private def splitCondition(condition: Expression, outputSet: AttributeSet): Expression = { + condition.transformUp { + case Or(a: And, b: And) => + Or(splitAndExp(a, outputSet), splitAndExp(b, outputSet)) + case Or(a: And, b) => + Or(splitAndExp(a, outputSet), b) + case Or(a, b: And) => + Or(a, splitAndExp(b, outputSet)) + } + } + + // Restore expressions from SameSide. + private def restoreExps(condition: Expression): Expression = { + condition match { + case SameSide(exps) => + exps.reduceLeft(And) + case Or(a, b) => + Or(restoreExps(a), restoreExps(b)) + case And(a, b) => + And(restoreExps(a), restoreExps(b)) + case other => other + } + } + + private def maybeWithFilter(joinCondition: Option[Expression], plan: LogicalPlan) = { + (joinCondition, plan) match { case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) => plan case (Some(condition), _) => @@ -1424,15 +1469,18 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { case j @ Join(left, right, joinType, Some(joinCondition), hint) => - val pushDownCandidates = splitConjunctivePredicates(rewriteToCNF(joinCondition)) - .filter(_.deterministic) + val pushDownCandidates = + splitConjunctivePredicates(rewriteToCNF(splitCondition(joinCondition, left.outputSet))) + .filter(_.deterministic) val (leftEvaluateCondition, rest) = pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, _) = rest.partition(expr => expr.references.subsetOf(right.outputSet)) - val newLeft = maybeWithFilter(leftEvaluateCondition, left) - val newRight = maybeWithFilter(rightEvaluateCondition, right) + val newLeft = + maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And).map(restoreExps), left) + val newRight = + maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And).map(restoreExps), right) joinType match { case _: InnerLike | LeftSemi => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index aed5d23d197b1..1d1132fc773e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -554,7 +554,7 @@ 
object SQLConf { .intConf .checkValue(_ >= 0, "The maximum depth of rewriting to conjunctive normal form must be non-negative.") - .createWithDefault(6) + .createWithDefault(10) val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index d8a05279a5df0..e29d182657b9f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -1334,4 +1334,24 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("inner join: rewrite to conjunctive normal form avoid genereting too many predicates") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.subquery('x) + val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y) + val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze + + comparePlans(optimized, correctAnswer) + } } From 6c44d64d827f9d10193fd4c54fc2539fb92d9bf7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 1 Jun 2020 00:00:58 +0800 Subject: [PATCH 3/4] Remove SameSide --- .../sql/catalyst/optimizer/Optimizer.scala | 116 ++++++++---------- .../optimizer/FilterPushdownSuite.scala | 29 +++++ 2 files changed, 78 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index fb06cc0806ef5..664a2306d1ba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,12 +20,10 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -121,6 +119,8 @@ abstract class Optimizer(catalogManager: CatalogManager) InferFiltersFromConstraints) :: Batch("Operator Optimization after Inferring Filters", fixedPoint, rulesWithoutInferFiltersFromConstraints: _*) :: + // Set strategy to Once to avoid pushing filters every time because we do not change the + // join condition. 
Batch("Push predicate through join by conjunctive normal form", Once, PushPredicateThroughJoinByCNF) :: Nil } @@ -1381,80 +1381,65 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { * more predicate. */ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper { - - // Used to group same side expressions to avoid generating too many duplicate predicates. - private case class SameSide(exps: Seq[Expression]) extends CodegenFallback { - override def children: Seq[Expression] = exps - override def nullable: Boolean = true - override def dataType: DataType = throw new UnsupportedOperationException - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException - } - /** * Rewrite pattern: * 1. (a && b) || c --> (a || c) && (b || c) * 2. a || (b && c) --> (a || b) && (a || c) - * 3. !(a || b) --> !a && !b + * + * To avoid generating too many predicates, we first group the filter columns from the same table. */ - private def rewriteToCNF(condition: Expression, depth: Int = 0): Expression = { + private def toCNF(condition: Expression, depth: Int = 0): Expression = { if (depth < SQLConf.get.maxRewritingCNFDepth) { - val nextDepth = depth + 1 condition match { - case Or(And(a, b), c) => - And(rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth), - rewriteToCNF(Or(rewriteToCNF(b, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth)) - case Or(a, And(b, c)) => - And(rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth)), nextDepth), - rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth)) - case Not(Or(a, b)) => - And(rewriteToCNF(Not(rewriteToCNF(a, nextDepth)), nextDepth), - rewriteToCNF(Not(rewriteToCNF(b, nextDepth)), nextDepth)) - case And(a, b) => - And(rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth)) - case other => other - } - } else { - condition - } - } + case or @ Or(left: And, right: And) => + val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier)) + val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier)) + if (lhs.size > 1) { + lhs.values.map(_.reduceLeft(And)).map { c => + toCNF(Or(toCNF(c, depth + 1), toCNF(right, depth + 1)), depth + 1) + }.reduce(And) + } else if (rhs.size > 1) { + rhs.values.map(_.reduceLeft(And)).map { c => + toCNF(Or(toCNF(left, depth + 1), toCNF(c, depth + 1)), depth + 1) + }.reduce(And) + } else { + or + } - /** - * Split And expression by single side references. 
For example, - * t1.a > 1 and t1.a < 10 and t2.a < 10 --> - * SameSide(t1.a > 1, t1.a < 10) and SameSide(t2.a < 10) - */ - private def splitAndExp(and: And, outputSet: AttributeSet) = { - val (leftSide, rightSide) = - splitConjunctivePredicates(and).partition(_.references.subsetOf(outputSet)) - Seq(SameSide(leftSide), SameSide(rightSide)).filter(_.exps.nonEmpty).reduceLeft(And) - } + case or @ Or(left: And, right) => + val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier)) + if (lhs.size > 1) { + lhs.values.map(_.reduceLeft(And)).map { + c => toCNF(Or(toCNF(c, depth + 1), toCNF(right, depth + 1)), depth + 1) + }.reduce(And) + } else { + or + } - private def splitCondition(condition: Expression, outputSet: AttributeSet): Expression = { - condition.transformUp { - case Or(a: And, b: And) => - Or(splitAndExp(a, outputSet), splitAndExp(b, outputSet)) - case Or(a: And, b) => - Or(splitAndExp(a, outputSet), b) - case Or(a, b: And) => - Or(a, splitAndExp(b, outputSet)) - } - } + case or @ Or(left, right: And) => + val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier)) + if (rhs.size > 1) { + rhs.values.map(_.reduceLeft(And)).map { c => + toCNF(Or(toCNF(left, depth + 1), toCNF(c, depth + 1)), depth + 1) + }.reduce(And) + } else { + or + } - // Restore expressions from SameSide. - private def restoreExps(condition: Expression): Expression = { - condition match { - case SameSide(exps) => - exps.reduceLeft(And) - case Or(a, b) => - Or(restoreExps(a), restoreExps(b)) - case And(a, b) => - And(restoreExps(a), restoreExps(b)) - case other => other + case And(left, right) => + And(toCNF(left, depth + 1), toCNF(right, depth + 1)) + + case other => + other + } + } else { + condition } } private def maybeWithFilter(joinCondition: Option[Expression], plan: LogicalPlan) = { (joinCondition, plan) match { + // Avoid adding the same filter. 
case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) => plan case (Some(condition), _) => @@ -1470,17 +1455,14 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel case j @ Join(left, right, joinType, Some(joinCondition), hint) => val pushDownCandidates = - splitConjunctivePredicates(rewriteToCNF(splitCondition(joinCondition, left.outputSet))) - .filter(_.deterministic) + splitConjunctivePredicates(toCNF(joinCondition)).filter(_.deterministic) val (leftEvaluateCondition, rest) = pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, _) = rest.partition(expr => expr.references.subsetOf(right.outputSet)) - val newLeft = - maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And).map(restoreExps), left) - val newRight = - maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And).map(restoreExps), right) + val newLeft = maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And), left) + val newRight = maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And), right) joinType match { case _: InnerLike | LeftSemi => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index e29d182657b9f..66148017f5989 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, IntegerType} import org.apache.spark.unsafe.types.CalendarInterval @@ -1354,4 +1355,32 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_REWRITING_CNF_DEPTH.key}=0") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))) + } + + Seq(0, 10).foreach { depth => + withSQLConf(SQLConf.MAX_REWRITING_CNF_DEPTH.key -> depth.toString) { + val optimized = Optimize.execute(originalQuery.analyze) + val (left, right) = if (depth == 0) { + (testRelation.subquery('x), testRelation.subquery('y)) + } else { + (testRelation.subquery('x), + testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)) + } + val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze + + comparePlans(optimized, correctAnswer) + } + } + } } From e8f24714d90ef7a2fbcd2458463d91c05e28ebbf Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 3 Jun 2020 18:40:27 +0800 Subject: [PATCH 4/4] Remove maybeWithFilter and add PushPredicateThroughJoinByCNF to blacklistedOnceBatches --- .../sql/catalyst/optimizer/Optimizer.scala | 55 ++++++++----------- .../apache/spark/sql/internal/SQLConf.scala | 4 +- .../optimizer/FilterPushdownSuite.scala | 8 ++- 3 files changed, 30 insertions(+), 37 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 664a2306d1ba1..04a89760b9e9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -51,7 +51,8 @@ abstract class Optimizer(catalogManager: CatalogManager) override protected val blacklistedOnceBatches: Set[String] = Set( "PartitionPruning", - "Extract Python UDFs") + "Extract Python UDFs", + "Push predicate through join by CNF") protected def fixedPoint = FixedPoint( @@ -121,7 +122,7 @@ abstract class Optimizer(catalogManager: CatalogManager) rulesWithoutInferFiltersFromConstraints: _*) :: // Set strategy to Once to avoid pushing filters every time because we do not change the // join condition. - Batch("Push predicate through join by conjunctive normal form", Once, + Batch("Push predicate through join by CNF", Once, PushPredicateThroughJoinByCNF) :: Nil } @@ -1383,10 +1384,11 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper { /** * Rewrite pattern: - * 1. (a && b) || c --> (a || c) && (b || c) - * 2. a || (b && c) --> (a || b) && (a || c) + * 1. (a && b) || (c && d) --> (a || c) && (a || d) && (b || c) && (b || d) + * 2. (a && b) || c --> (a || c) && (b || c) + * 3. a || (b && c) --> (a || b) && (a || c) * - * To avoid generating too many predicates, we first group the filter columns from the same table. + * To avoid generating too many predicates, we first group the columns from the same table. */ private def toCNF(condition: Expression, depth: Int = 0): Expression = { if (depth < SQLConf.get.maxRewritingCNFDepth) { @@ -1395,12 +1397,12 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier)) val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier)) if (lhs.size > 1) { - lhs.values.map(_.reduceLeft(And)).map { c => - toCNF(Or(toCNF(c, depth + 1), toCNF(right, depth + 1)), depth + 1) + lhs.values.map(_.reduceLeft(And)).map { e => + toCNF(Or(toCNF(e, depth + 1), toCNF(right, depth + 1)), depth + 1) }.reduce(And) } else if (rhs.size > 1) { - rhs.values.map(_.reduceLeft(And)).map { c => - toCNF(Or(toCNF(left, depth + 1), toCNF(c, depth + 1)), depth + 1) + rhs.values.map(_.reduceLeft(And)).map { e => + toCNF(Or(toCNF(left, depth + 1), toCNF(e, depth + 1)), depth + 1) }.reduce(And) } else { or @@ -1409,8 +1411,8 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel case or @ Or(left: And, right) => val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier)) if (lhs.size > 1) { - lhs.values.map(_.reduceLeft(And)).map { - c => toCNF(Or(toCNF(c, depth + 1), toCNF(right, depth + 1)), depth + 1) + lhs.values.map(_.reduceLeft(And)).map { e => + toCNF(Or(toCNF(e, depth + 1), toCNF(right, depth + 1)), depth + 1) }.reduce(And) } else { or @@ -1419,8 +1421,8 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel case or @ Or(left, right: And) => val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier)) if (rhs.size > 1) { - rhs.values.map(_.reduceLeft(And)).map { c => - toCNF(Or(toCNF(left, depth + 1), toCNF(c, depth + 1)), depth + 1) + 
rhs.values.map(_.reduceLeft(And)).map { e => + toCNF(Or(toCNF(left, depth + 1), toCNF(e, depth + 1)), depth + 1) }.reduce(And) } else { or @@ -1437,32 +1439,19 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel } } - private def maybeWithFilter(joinCondition: Option[Expression], plan: LogicalPlan) = { - (joinCondition, plan) match { - // Avoid adding the same filter. - case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) => - plan - case (Some(condition), _) => - Filter(condition, plan) - case _ => - plan - } - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally - - val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case j @ Join(left, right, joinType, Some(joinCondition), hint) => - val pushDownCandidates = splitConjunctivePredicates(toCNF(joinCondition)).filter(_.deterministic) - val (leftEvaluateCondition, rest) = + val (leftFilterConditions, rest) = pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) - val (rightEvaluateCondition, _) = + val (rightFilterConditions, _) = rest.partition(expr => expr.references.subsetOf(right.outputSet)) - val newLeft = maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And), left) - val newRight = maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And), right) + val newLeft = leftFilterConditions. + reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + val newRight = rightFilterConditions. + reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) joinType match { case _: InnerLike | LeftSemi => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1d1132fc773e6..d8ac3943d2336 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -548,8 +548,8 @@ object SQLConf { buildConf("spark.sql.maxRewritingCNFDepth") .internal() .doc("The maximum depth of rewriting a join condition to conjunctive normal form " + - "expression. The deeper, the more predicates may be found, but the optimization time " + - "will increase. The default is 6. By setting this value to 0 this feature can be disabled.") + "expression. The deeper, the more predicates may be found, but the optimization time will " + + "increase. The default is 10. 
By setting this value to 0 this feature can be disabled.") .version("3.1.0") .intConf .checkValue(_ >= 0, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 66148017f5989..d3c338c5789dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -32,6 +32,10 @@ import org.apache.spark.unsafe.types.CalendarInterval class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { + + override protected val blacklistedOnceBatches: Set[String] = + Set("Push predicate through join by CNF") + val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: @@ -41,7 +45,7 @@ class FilterPushdownSuite extends PlanTest { BooleanSimplification, PushPredicateThroughJoin, CollapseProject) :: - Batch("PushPredicateThroughJoinByCNF", Once, + Batch("Push predicate through join by CNF", Once, PushPredicateThroughJoinByCNF) :: Nil } @@ -1336,7 +1340,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("inner join: rewrite to conjunctive normal form avoid genereting too many predicates") { + test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") { val x = testRelation.subquery('x) val y = testRelation.subquery('y)
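
Note on the rewrite itself: the distribution step that toCNF performs is easier to see in isolation. The following is a minimal standalone Scala sketch of the two OR-over-AND patterns with the same depth cutoff; the Expr ADT, Leaf, and maxDepth names are illustrative stand-ins rather than Catalyst's Expression classes or the real conf plumbing, and the per-qualifier grouping introduced in patches 2-4 is omitted so the algebra stays visible.

  sealed trait Expr
  case class Leaf(name: String) extends Expr { override def toString: String = name }
  case class And(l: Expr, r: Expr) extends Expr { override def toString: String = s"($l && $r)" }
  case class Or(l: Expr, r: Expr) extends Expr { override def toString: String = s"($l || $r)" }

  object CnfSketch {
    // Stand-in for spark.sql.maxRewritingCNFDepth: bounds the exponential blow-up.
    val maxDepth = 10

    def toCNF(e: Expr, depth: Int = 0): Expr =
      if (depth >= maxDepth) e
      else e match {
        // (a && b) || c --> (a || c) && (b || c)
        case Or(And(a, b), c) => And(toCNF(Or(a, c), depth + 1), toCNF(Or(b, c), depth + 1))
        // a || (b && c) --> (a || b) && (a || c)
        case Or(a, And(b, c)) => And(toCNF(Or(a, b), depth + 1), toCNF(Or(a, c), depth + 1))
        case And(a, b) => And(toCNF(a, depth + 1), toCNF(b, depth + 1))
        case other => other
      }

    def main(args: Array[String]): Unit = {
      // The join condition used throughout the new tests:
      // (x.a > 3 && y.a > 13) || (x.a > 1 && y.a > 11)
      val cond = Or(
        And(Leaf("x.a > 3"), Leaf("y.a > 13")),
        And(Leaf("x.a > 1"), Leaf("y.a > 11")))
      // Prints (((x.a > 3 || x.a > 1) && (x.a > 3 || y.a > 11)) &&
      // ((y.a > 13 || x.a > 1) && (y.a > 13 || y.a > 11)))
      println(toCNF(cond))
    }
  }

The left-only conjunct (x.a > 3 || x.a > 1) and the right-only conjunct (y.a > 13 || y.a > 11) are exactly the filters the rule can push below the join, matching the testRelation.where(...) expectations asserted in the suite.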
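As an end-to-end sanity check, the effect can be observed by comparing optimized plans with the rewrite on and off. The snippet below is hypothetical: it assumes a SparkSession named spark built from a branch containing this patch series, and two assumed tables t1/t2 that each have integer columns a and b.

  // Rewrite disabled: the whole disjunction stays in the join condition.
  spark.conf.set("spark.sql.maxRewritingCNFDepth", "0")
  spark.sql(
    """SELECT * FROM t1 JOIN t2 ON t1.b = t2.b
      |  AND ((t1.a > 3 AND t2.a > 13) OR (t1.a > 1 AND t2.a > 11))""".stripMargin)
    .explain(true)
  // With the default of 10 instead, the optimized plan additionally shows
  // Filter ((a > 3) OR (a > 1)) above t1's scan and
  // Filter ((a > 13) OR (a > 11)) above t2's scan.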