From 3356bac5cb04ba25c9f587a02f63bcfd885b34ff Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 8 Jun 2020 11:20:34 +0800 Subject: [PATCH 01/26] WIP --- .../sql/catalyst/optimizer/Optimizer.scala | 6 +- .../PushCNFPredicateThroughScan.scala | 128 ++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 15 ++ .../sql/hive/execution/SQLQuerySuite.scala | 34 +++++ 4 files changed, 182 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughScan.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f1a307b1c2cc1..2af8a30a83d18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,7 +118,11 @@ abstract class Optimizer(catalogManager: CatalogManager) Batch("Infer Filters", Once, InferFiltersFromConstraints) :: Batch("Operator Optimization after Inferring Filters", fixedPoint, - rulesWithoutInferFiltersFromConstraints: _*) :: Nil + rulesWithoutInferFiltersFromConstraints: _*) :: + // Set strategy to Once to avoid pushing filter every time because we do not change the + // join condition. + Batch("Push CNF predicate through join", Once, + PushCNFPredicateThroughScan) :: Nil } val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughScan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughScan.scala new file mode 100644 index 0000000000000..6e89b5208cf7e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughScan.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import javax.management.relation.Relation + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.expressions.{And, Expression, Not, Or, PredicateHelper} +import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Try converting join condition to conjunctive normal form expression so that more predicates may + * be able to be pushed down. + * To avoid expanding the join condition, the join condition will be kept in the original form even + * when predicate pushdown happens. 
+ */ +object PushCNFPredicateThroughScan extends Rule[LogicalPlan] with PredicateHelper { + /** + * Convert an expression into conjunctive normal form. + * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form + * CNF can explode exponentially in the size of the input expression when converting Or clauses. + * Use a configuration MAX_CNF_NODE_COUNT to prevent such cases. + * + * @param condition to be conversed into CNF. + * @return If the number of expressions exceeds threshold on converting Or, return Seq.empty. + * If the conversion repeatedly expands nondeterministic expressions, return Seq.empty. + * Otherwise, return the converted result as sequence of disjunctive expressions. + */ + protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = { + val postOrderNodes = postOrderTraversal(condition) + val resultStack = new scala.collection.mutable.Stack[Seq[Expression]] + // val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount + val maxCnfNodeCount = 100 + // Bottom up approach to get CNF of sub-expressions + while (postOrderNodes.nonEmpty) { + val cnf = postOrderNodes.pop() match { + case _: And => + val right: Seq[Expression] = resultStack.pop() + val left: Seq[Expression] = resultStack.pop() + left ++ right + case _: Or => + // For each side, there is no need to expand predicates of the same references. + // So here we can aggregate predicates of the same references as one single predicate, + // for reducing the size of pushed down predicates and corresponding codegen. + val right = aggregateExpressionsOfSameReference(resultStack.pop()) + val left = aggregateExpressionsOfSameReference(resultStack.pop()) + // Stop the loop whenever the result exceeds the `maxCnfNodeCount` + if (left.size * right.size > maxCnfNodeCount) { + Seq.empty + } else { + for {x <- left; y <- right} yield Or(x, y) + } + case other => other :: Nil + } + if (cnf.isEmpty) { + return Seq.empty + } + resultStack.push(cnf) + } + assert(resultStack.length == 1, + s"Fail to convert expression ${condition} to conjunctive normal form") + resultStack.top + } + + private def aggregateExpressionsOfSameReference(expressions: Seq[Expression]): Seq[Expression] = { + expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq + } + + /** + * Iterative post order traversal over a binary tree built by And/Or clauses. + * + * @param condition to be traversed as binary tree + * @return sub-expressions in post order traversal as an Array. + * The first element of result Array is the leftmost node. 
+ */ + private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = { + val stack = new mutable.Stack[Expression] + val result = new mutable.Stack[Expression] + stack.push(condition) + while (stack.nonEmpty) { + val node = stack.pop() + node match { + case Not(a And b) => stack.push(Or(Not(a), Not(b))) + case Not(a Or b) => stack.push(And(Not(a), Not(b))) + case Not(Not(a)) => stack.push(a) + case a And b => + result.push(node) + stack.push(a) + stack.push(b) + case a Or b => + result.push(node) + stack.push(a) + stack.push(b) + case _ => + result.push(node) + } + } + result + } + + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case ScanOperation(projectList, conditions, relation: HiveTableRelation) => + val predicates = conjunctiveNormalForm(conditions.reduceLeft(And)) + return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f7a904169d6c3..4875ea605143e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3503,6 +3503,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(sql("select CAST(-32768 as short) DIV CAST (-1 as short)"), Seq(Row(Short.MinValue.toLong * -1))) } + + test("test") { + withTable("t") { + sql( + """ + |create table t(id int, dt string) using orc partitioned by (dt) + """.stripMargin) + + sql( + """ + |select * from t where dt = '1' or (dt = '2' and id = 1) + """.stripMargin).explain(true) + } + + } } case class Foo(bar: Option[String]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 79c6ade2807d3..e600368bf434c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2544,6 +2544,40 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi assert(e.getMessage.contains("Cannot modify the value of a static config")) } } + + test("test") { + Seq("false").foreach { convertParquet => + withTable("t") { + withTempDir { f => + sql("CREATE EXTERNAL TABLE t(id int) PARTITIONED BY (dt string) STORED AS " + + s"PARQUET LOCATION '${f.getAbsolutePath}'") + + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> convertParquet) { + sql( + """ + |SELECT * FROM t WHERE dt = '1' OR (dt = '2' AND id = 1) + """.stripMargin).explain(true) + + sql( + """ + |SELECT * FROM t WHERE (dt = '20190624' and id = 2) or (id = 1 or dt = '20190625') + """.stripMargin).explain(true) + + sql( + """ + |SELECT * FROM t WHERE (dt = '20190624' and id = 2) or (dt = '20190625' and id = 3 ); + """.stripMargin).explain(true) + + sql( + """ + |SELECT * FROM t WHERE (dt = '20190624' and id = 2) or (dt = '20190630' or dt = '20190625') + """.stripMargin).explain(true) + } + } + } + } + } + } class SQLQuerySuite extends SQLQuerySuiteBase with DisableAdaptiveExecutionSuite From 346a1b4e36e72735ce3393061e7724feb14c0121 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 8 Jun 2020 19:20:46 +0800 Subject: [PATCH 02/26] save --- .../sql/catalyst/expressions/predicates.scala | 135 ++++++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 6 +- 
.../PushCNFPredicateThroughScan.scala | 128 ----------------- .../apache/spark/sql/internal/SQLConf.scala | 15 ++ .../PushCNFPredicateThroughFileScan.scala | 42 ++++++ .../spark/sql/execution/SparkOptimizer.scala | 3 +- .../sql/hive/HiveSessionStateBuilder.scala | 4 +- ...PushCNFPredicateThroughHiveTableScan.scala | 41 ++++++ .../hive/execution/HiveTableScanSuite.scala | 38 +++++ .../sql/hive/execution/SQLQuerySuite.scala | 34 ----- 10 files changed, 276 insertions(+), 170 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughScan.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2c4f41f98ac20..679c61a34fc92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.immutable.TreeSet +import scala.collection.mutable import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.InternalRow @@ -198,6 +199,140 @@ trait PredicateHelper { case e: Unevaluable => false case e => e.children.forall(canEvaluateWithinJoin) } + + + /** + * Convert an expression into conjunctive normal form. + * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form + * CNF can explode exponentially in the size of the input expression when converting Or clauses. + * Use a configuration MAX_CNF_NODE_COUNT to prevent such cases. + * + * @param condition to be conversed into CNF. + * @return If the number of expressions exceeds threshold on converting Or, return Seq.empty. + * If the conversion repeatedly expands nondeterministic expressions, return Seq.empty. + * Otherwise, return the converted result as sequence of disjunctive expressions. + */ + def conjunctiveNormalForm(condition: Expression): Seq[Expression] = { + val postOrderNodes = postOrderTraversal(condition) + val resultStack = new mutable.Stack[Seq[Expression]] + val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount + // Bottom up approach to get CNF of sub-expressions + while (postOrderNodes.nonEmpty) { + val cnf = postOrderNodes.pop() match { + case _: And => + val right: Seq[Expression] = resultStack.pop() + val left: Seq[Expression] = resultStack.pop() + left ++ right + case _: Or => + // For each side, there is no need to expand predicates of the same references. + // So here we can aggregate predicates of the same references as one single predicate, + // for reducing the size of pushed down predicates and corresponding codegen. 
+ val right = aggregateExpressionsOfSameQualifiers(resultStack.pop()) + val left = aggregateExpressionsOfSameQualifiers(resultStack.pop()) + // Stop the loop whenever the result exceeds the `maxCnfNodeCount` + if (left.size * right.size > maxCnfNodeCount) { + Seq.empty + } else { + for {x <- left; y <- right} yield Or(x, y) + } + case other => other :: Nil + } + if (cnf.isEmpty) { + return Seq.empty + } + resultStack.push(cnf) + } + assert(resultStack.length == 1, + s"Fail to convert expression ${condition} to conjunctive normal form") + resultStack.top + } + + /** + * Convert an expression into conjunctive normal form. + * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form + * CNF can explode exponentially in the size of the input expression when converting Or clauses. + * Use a configuration MAX_CNF_NODE_COUNT to prevent such cases. + * + * @param condition to be conversed into CNF. + * @return If the number of expressions exceeds threshold on converting Or, return Seq.empty. + * If the conversion repeatedly expands nondeterministic expressions, return Seq.empty. + * Otherwise, return the converted result as sequence of disjunctive expressions. + */ + def conjunctiveNormalFormForPartitionPruning(condition: Expression): Seq[Expression] = { + val postOrderNodes = postOrderTraversal(condition) + val resultStack = new mutable.Stack[Seq[Expression]] + val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount + // Bottom up approach to get CNF of sub-expressions + while (postOrderNodes.nonEmpty) { + val cnf = postOrderNodes.pop() match { + case _: And => + val right: Seq[Expression] = resultStack.pop() + val left: Seq[Expression] = resultStack.pop() + left ++ right + case _: Or => + // For each side, there is no need to expand predicates of the same references. + // So here we can aggregate predicates of the same references as one single predicate, + // for reducing the size of pushed down predicates and corresponding codegen. + val right = aggregateExpressionsOfSameReference(resultStack.pop()) + val left = aggregateExpressionsOfSameReference(resultStack.pop()) + // Stop the loop whenever the result exceeds the `maxCnfNodeCount` + if (left.size * right.size > maxCnfNodeCount) { + Seq.empty + } else { + for {x <- left; y <- right} yield Or(x, y) + } + case other => other :: Nil + } + if (cnf.isEmpty) { + return Seq.empty + } + resultStack.push(cnf) + } + assert(resultStack.length == 1, + s"Fail to convert expression ${condition} to conjunctive normal form") + resultStack.top + } + + private def aggregateExpressionsOfSameQualifiers( + expressions: Seq[Expression]): Seq[Expression] = { + expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq + } + + private def aggregateExpressionsOfSameReference( + expressions: Seq[Expression]): Seq[Expression] = { + expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq + } + + /** + * Iterative post order traversal over a binary tree built by And/Or clauses. + * @param condition to be traversed as binary tree + * @return sub-expressions in post order traversal as an Array. + * The first element of result Array is the leftmost node. 
+ */ + private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = { + val stack = new mutable.Stack[Expression] + val result = new mutable.Stack[Expression] + stack.push(condition) + while (stack.nonEmpty) { + val node = stack.pop() + node match { + case Not(a And b) => stack.push(Or(Not(a), Not(b))) + case Not(a Or b) => stack.push(And(Not(a), Not(b))) + case Not(Not(a)) => stack.push(a) + case a And b => + result.push(node) + stack.push(a) + stack.push(b) + case a Or b => + result.push(node) + stack.push(a) + stack.push(b) + case _ => + result.push(node) + } + } + result + } } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2af8a30a83d18..f1a307b1c2cc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,11 +118,7 @@ abstract class Optimizer(catalogManager: CatalogManager) Batch("Infer Filters", Once, InferFiltersFromConstraints) :: Batch("Operator Optimization after Inferring Filters", fixedPoint, - rulesWithoutInferFiltersFromConstraints: _*) :: - // Set strategy to Once to avoid pushing filter every time because we do not change the - // join condition. - Batch("Push CNF predicate through join", Once, - PushCNFPredicateThroughScan) :: Nil + rulesWithoutInferFiltersFromConstraints: _*) :: Nil } val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughScan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughScan.scala deleted file mode 100644 index 6e89b5208cf7e..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughScan.scala +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import javax.management.relation.Relation - -import scala.collection.mutable - -import org.apache.spark.sql.catalyst.catalog.HiveTableRelation -import org.apache.spark.sql.catalyst.expressions.{And, Expression, Not, Or, PredicateHelper} -import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * Try converting join condition to conjunctive normal form expression so that more predicates may - * be able to be pushed down. 
- * To avoid expanding the join condition, the join condition will be kept in the original form even - * when predicate pushdown happens. - */ -object PushCNFPredicateThroughScan extends Rule[LogicalPlan] with PredicateHelper { - /** - * Convert an expression into conjunctive normal form. - * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form - * CNF can explode exponentially in the size of the input expression when converting Or clauses. - * Use a configuration MAX_CNF_NODE_COUNT to prevent such cases. - * - * @param condition to be conversed into CNF. - * @return If the number of expressions exceeds threshold on converting Or, return Seq.empty. - * If the conversion repeatedly expands nondeterministic expressions, return Seq.empty. - * Otherwise, return the converted result as sequence of disjunctive expressions. - */ - protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = { - val postOrderNodes = postOrderTraversal(condition) - val resultStack = new scala.collection.mutable.Stack[Seq[Expression]] - // val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount - val maxCnfNodeCount = 100 - // Bottom up approach to get CNF of sub-expressions - while (postOrderNodes.nonEmpty) { - val cnf = postOrderNodes.pop() match { - case _: And => - val right: Seq[Expression] = resultStack.pop() - val left: Seq[Expression] = resultStack.pop() - left ++ right - case _: Or => - // For each side, there is no need to expand predicates of the same references. - // So here we can aggregate predicates of the same references as one single predicate, - // for reducing the size of pushed down predicates and corresponding codegen. - val right = aggregateExpressionsOfSameReference(resultStack.pop()) - val left = aggregateExpressionsOfSameReference(resultStack.pop()) - // Stop the loop whenever the result exceeds the `maxCnfNodeCount` - if (left.size * right.size > maxCnfNodeCount) { - Seq.empty - } else { - for {x <- left; y <- right} yield Or(x, y) - } - case other => other :: Nil - } - if (cnf.isEmpty) { - return Seq.empty - } - resultStack.push(cnf) - } - assert(resultStack.length == 1, - s"Fail to convert expression ${condition} to conjunctive normal form") - resultStack.top - } - - private def aggregateExpressionsOfSameReference(expressions: Seq[Expression]): Seq[Expression] = { - expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq - } - - /** - * Iterative post order traversal over a binary tree built by And/Or clauses. - * - * @param condition to be traversed as binary tree - * @return sub-expressions in post order traversal as an Array. - * The first element of result Array is the leftmost node. 
- */ - private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = { - val stack = new mutable.Stack[Expression] - val result = new mutable.Stack[Expression] - stack.push(condition) - while (stack.nonEmpty) { - val node = stack.pop() - node match { - case Not(a And b) => stack.push(Or(Not(a), Not(b))) - case Not(a Or b) => stack.push(And(Not(a), Not(b))) - case Not(Not(a)) => stack.push(a) - case a And b => - result.push(node) - stack.push(a) - stack.push(b) - case a Or b => - result.push(node) - stack.push(a) - stack.push(b) - case _ => - result.push(node) - } - } - result - } - - - def apply(plan: LogicalPlan): LogicalPlan = { - plan transform { - case ScanOperation(projectList, conditions, relation: HiveTableRelation) => - val predicates = conjunctiveNormalForm(conditions.reduceLeft(And)) - return Project(projectList, Filter(predicates.reduceLeft(And), relation)) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3a41b0553db54..2294213da7deb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -545,6 +545,19 @@ object SQLConf { .booleanConf .createWithDefault(true) + val MAX_CNF_NODE_COUNT = + buildConf("spark.sql.optimizer.maxCNFNodeCount") + .internal() + .doc("Specifies the maximum allowable number of conjuncts in the result of CNF " + + "conversion. If the conversion exceeds the threshold, None is returned. " + + "For example, CNF conversion of (a && b) || (c && d) generates " + + "four conjuncts (a || c) && (a || d) && (b || c) && (b || d).") + .version("3.1.0") + .intConf + .checkValue(_ >= 0, + "The depth of the maximum rewriting conjunction normal form must be positive.") + .createWithDefault(128) + val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") .internal() .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + @@ -2871,6 +2884,8 @@ class SQLConf extends Serializable with Logging { def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT) + def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR) def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala new file mode 100644 index 0000000000000..ab9d0769ba23e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper} +import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.LogicalRelation + +/** + * Try converting join condition to conjunctive normal form expression so that more predicates may + * be able to be pushed down. + * To avoid expanding the join condition, the join condition will be kept in the original form even + * when predicate pushdown happens. + */ +object PushCNFPredicateThroughFileScan extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case ScanOperation(projectList, conditions, relation: LogicalRelation) => + val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) + return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 33b86a2b5340c..4c8ae8d48437a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -37,7 +37,8 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst - SchemaPruning :: V2ScanRelationPushDown :: PruneFileSourcePartitions :: Nil + SchemaPruning :: V2ScanRelationPushDown :: PushCNFPredicateThroughFileScan :: + PruneFileSourcePartitions :: Nil override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 64726755237a6..0c7ad54dc94b7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.command.CommandCheck import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.TableCapabilityCheck import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.hive.execution.PruneHiveTablePartitions +import org.apache.spark.sql.hive.execution.{PruneHiveTablePartitions, PushCNFPredicateThroughHiveTableScan} import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** @@ -99,7 +99,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session } override def customEarlyScanPushDownRules: Seq[Rule[LogicalPlan]] = - Seq(new PruneHiveTablePartitions(session)) + Seq(PushCNFPredicateThroughHiveTableScan, new PruneHiveTablePartitions(session)) /** * Planner that takes into account Hive-specific strategies. 
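Both scan rules in this revision (PushCNFPredicateThroughFileScan above and PushCNFPredicateThroughHiveTableScan below) rest on the same rewrite: turn a scan's filter condition into conjunctive normal form so that conjuncts touching only partition columns can drive partition pruning. The following standalone Scala sketch illustrates that rewrite on a toy expression tree; the names Expr, Leaf and toCnf are illustrative only and are not Spark APIs, and the Not/De Morgan handling and same-reference grouping of the real implementation are omitted. The cap mirrors the intent of the spark.sql.optimizer.maxCNFNodeCount setting introduced in this patch.

// Toy CNF conversion sketch (not Spark code): distribute Or over And and give up
// once the number of conjuncts would exceed a cap, similar in spirit to
// spark.sql.optimizer.maxCNFNodeCount. Negation (De Morgan) handling is omitted.
object CnfSketch extends App {
  sealed trait Expr
  case class Leaf(pred: String) extends Expr
  case class And(left: Expr, right: Expr) extends Expr
  case class Or(left: Expr, right: Expr) extends Expr

  // Returns the conjuncts of the CNF form, or Seq.empty if the cap is exceeded.
  def toCnf(e: Expr, maxConjuncts: Int = 128): Seq[Expr] = {
    def loop(e: Expr): Option[Seq[Expr]] = e match {
      case And(l, r) =>
        // CNF(l And r) is just the union of the two conjunct sets.
        for (lc <- loop(l); rc <- loop(r)) yield lc ++ rc
      case Or(l, r) =>
        for {
          lc <- loop(l)
          rc <- loop(r)
          if lc.size * rc.size <= maxConjuncts // stop instead of exploding
        } yield for (x <- lc; y <- rc) yield Or(x, y)
      case leaf =>
        Some(Seq(leaf))
    }
    loop(e).getOrElse(Seq.empty)
  }

  // p is a partition column, i is a data column:
  //   p = '1' OR (p = '2' AND i = 1)
  // becomes
  //   (p = '1' OR p = '2') AND (p = '1' OR i = 1)
  val condition = Or(Leaf("p = '1'"), And(Leaf("p = '2'"), Leaf("i = 1")))
  toCnf(condition).foreach(println)
}

For that example the first conjunct references only the partition column, which is what the HiveTableScanSuite test added later in this series expects: the scan of SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1) keeps only partitions p=1 and p=2.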
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala
new file mode 100644
index 0000000000000..48152acda21f5
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
+import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
+import org.apache.spark.sql.catalyst.planning.ScanOperation
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Try converting join condition to conjunctive normal form expression so that more predicates may
+ * be able to be pushed down.
+ * To avoid expanding the join condition, the join condition will be kept in the original form even
+ * when predicate pushdown happens.
+ */ +object PushCNFPredicateThroughHiveTableScan extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case ScanOperation(projectList, conditions, relation: HiveTableRelation) => + val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) + return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 67d7ed0841abb..954d7351494d2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -187,6 +189,42 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH } } + test("Convert scan predicate to CNF") { + withTable("t", "temp") { + sql( + s""" + |CREATE TABLE t(i int) + |PARTITIONED BY (p int) + |STORED AS textfile""".stripMargin) + spark.range(0, 1000, 1).selectExpr("id as col") + .createOrReplaceTempView("temp") + + for (part <- Seq(1, 2, 3, 4)) { + sql( + s""" + |INSERT OVERWRITE TABLE t PARTITION (p='$part') + |select col from temp""".stripMargin) + } + + val scan1 = getHiveTableScanExec("SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)") + val scan2 = getHiveTableScanExec( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')") + val scan3 = getHiveTableScanExec( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )") + val scan4 = getHiveTableScanExec( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')") + + assert(scan1.prunedPartitions.map(_.toString) == + Stream("t(p=1)", "t(p=2)")) + assert(scan2.prunedPartitions.map(_.toString) == + Stream("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) + assert(scan3.prunedPartitions.map(_.toString) == + Stream("t(p=1)", "t(p=3)")) + assert(scan4.prunedPartitions.map(_.toString) == + Stream("t(p=1)", "t(p=2)", "t(p=3)")) + } + } + private def getHiveTableScanExec(query: String): HiveTableScanExec = { sql(query).queryExecution.sparkPlan.collectFirst { case p: HiveTableScanExec => p diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index e600368bf434c..79c6ade2807d3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2544,40 +2544,6 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi assert(e.getMessage.contains("Cannot modify the value of a static config")) } } - - test("test") { - Seq("false").foreach { convertParquet => - withTable("t") { - withTempDir { f => - sql("CREATE EXTERNAL TABLE t(id int) PARTITIONED BY (dt string) STORED AS " + - s"PARQUET LOCATION '${f.getAbsolutePath}'") - - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> convertParquet) { - sql( - """ - 
|SELECT * FROM t WHERE dt = '1' OR (dt = '2' AND id = 1) - """.stripMargin).explain(true) - - sql( - """ - |SELECT * FROM t WHERE (dt = '20190624' and id = 2) or (id = 1 or dt = '20190625') - """.stripMargin).explain(true) - - sql( - """ - |SELECT * FROM t WHERE (dt = '20190624' and id = 2) or (dt = '20190625' and id = 3 ); - """.stripMargin).explain(true) - - sql( - """ - |SELECT * FROM t WHERE (dt = '20190624' and id = 2) or (dt = '20190630' or dt = '20190625') - """.stripMargin).explain(true) - } - } - } - } - } - } class SQLQuerySuite extends SQLQuerySuiteBase with DisableAdaptiveExecutionSuite From 250c7b386ee4352ecc33c5905c2d741008936d15 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 9 Jun 2020 09:57:58 +0800 Subject: [PATCH 03/26] Update HiveTableScanSuite.scala --- .../apache/spark/sql/hive/execution/HiveTableScanSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 954d7351494d2..ae2afa817d9c3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -206,7 +206,8 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH |select col from temp""".stripMargin) } - val scan1 = getHiveTableScanExec("SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)") + val scan1 = getHiveTableScanExec( + "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)") val scan2 = getHiveTableScanExec( "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')") val scan3 = getHiveTableScanExec( From 15d62be091a42602181c100f6c31d597691a9285 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 09:37:06 +0800 Subject: [PATCH 04/26] save --- .../sql/catalyst/expressions/predicates.scala | 134 ------------------ 1 file changed, 134 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 679c61a34fc92..9c1f63322dd09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -199,140 +199,6 @@ trait PredicateHelper { case e: Unevaluable => false case e => e.children.forall(canEvaluateWithinJoin) } - - - /** - * Convert an expression into conjunctive normal form. - * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form - * CNF can explode exponentially in the size of the input expression when converting Or clauses. - * Use a configuration MAX_CNF_NODE_COUNT to prevent such cases. - * - * @param condition to be conversed into CNF. - * @return If the number of expressions exceeds threshold on converting Or, return Seq.empty. - * If the conversion repeatedly expands nondeterministic expressions, return Seq.empty. - * Otherwise, return the converted result as sequence of disjunctive expressions. 
- */ - def conjunctiveNormalForm(condition: Expression): Seq[Expression] = { - val postOrderNodes = postOrderTraversal(condition) - val resultStack = new mutable.Stack[Seq[Expression]] - val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount - // Bottom up approach to get CNF of sub-expressions - while (postOrderNodes.nonEmpty) { - val cnf = postOrderNodes.pop() match { - case _: And => - val right: Seq[Expression] = resultStack.pop() - val left: Seq[Expression] = resultStack.pop() - left ++ right - case _: Or => - // For each side, there is no need to expand predicates of the same references. - // So here we can aggregate predicates of the same references as one single predicate, - // for reducing the size of pushed down predicates and corresponding codegen. - val right = aggregateExpressionsOfSameQualifiers(resultStack.pop()) - val left = aggregateExpressionsOfSameQualifiers(resultStack.pop()) - // Stop the loop whenever the result exceeds the `maxCnfNodeCount` - if (left.size * right.size > maxCnfNodeCount) { - Seq.empty - } else { - for {x <- left; y <- right} yield Or(x, y) - } - case other => other :: Nil - } - if (cnf.isEmpty) { - return Seq.empty - } - resultStack.push(cnf) - } - assert(resultStack.length == 1, - s"Fail to convert expression ${condition} to conjunctive normal form") - resultStack.top - } - - /** - * Convert an expression into conjunctive normal form. - * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form - * CNF can explode exponentially in the size of the input expression when converting Or clauses. - * Use a configuration MAX_CNF_NODE_COUNT to prevent such cases. - * - * @param condition to be conversed into CNF. - * @return If the number of expressions exceeds threshold on converting Or, return Seq.empty. - * If the conversion repeatedly expands nondeterministic expressions, return Seq.empty. - * Otherwise, return the converted result as sequence of disjunctive expressions. - */ - def conjunctiveNormalFormForPartitionPruning(condition: Expression): Seq[Expression] = { - val postOrderNodes = postOrderTraversal(condition) - val resultStack = new mutable.Stack[Seq[Expression]] - val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount - // Bottom up approach to get CNF of sub-expressions - while (postOrderNodes.nonEmpty) { - val cnf = postOrderNodes.pop() match { - case _: And => - val right: Seq[Expression] = resultStack.pop() - val left: Seq[Expression] = resultStack.pop() - left ++ right - case _: Or => - // For each side, there is no need to expand predicates of the same references. - // So here we can aggregate predicates of the same references as one single predicate, - // for reducing the size of pushed down predicates and corresponding codegen. 
- val right = aggregateExpressionsOfSameReference(resultStack.pop()) - val left = aggregateExpressionsOfSameReference(resultStack.pop()) - // Stop the loop whenever the result exceeds the `maxCnfNodeCount` - if (left.size * right.size > maxCnfNodeCount) { - Seq.empty - } else { - for {x <- left; y <- right} yield Or(x, y) - } - case other => other :: Nil - } - if (cnf.isEmpty) { - return Seq.empty - } - resultStack.push(cnf) - } - assert(resultStack.length == 1, - s"Fail to convert expression ${condition} to conjunctive normal form") - resultStack.top - } - - private def aggregateExpressionsOfSameQualifiers( - expressions: Seq[Expression]): Seq[Expression] = { - expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq - } - - private def aggregateExpressionsOfSameReference( - expressions: Seq[Expression]): Seq[Expression] = { - expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq - } - - /** - * Iterative post order traversal over a binary tree built by And/Or clauses. - * @param condition to be traversed as binary tree - * @return sub-expressions in post order traversal as an Array. - * The first element of result Array is the leftmost node. - */ - private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = { - val stack = new mutable.Stack[Expression] - val result = new mutable.Stack[Expression] - stack.push(condition) - while (stack.nonEmpty) { - val node = stack.pop() - node match { - case Not(a And b) => stack.push(Or(Not(a), Not(b))) - case Not(a Or b) => stack.push(And(Not(a), Not(b))) - case Not(Not(a)) => stack.push(a) - case a And b => - result.push(node) - stack.push(a) - stack.push(b) - case a Or b => - result.push(node) - stack.push(a) - stack.push(b) - case _ => - result.push(node) - } - } - result - } } @ExpressionDescription( From d8f7c9e19d810b0eecad8af16234e2cee8f14561 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 09:41:28 +0800 Subject: [PATCH 05/26] save --- .../sql/catalyst/expressions/predicates.scala | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c9b57367e0f44..89bad7e42ad56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -253,6 +253,58 @@ trait PredicateHelper extends Logging { expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq } + /** + * Convert an expression into conjunctive normal form for partition pruning. + * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form + * CNF can explode exponentially in the size of the input expression when converting [[Or]] + * clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases. + * + * @param condition to be converted into CNF. + * @return the CNF result as sequence of disjunctive expressions. If the number of expressions + * exceeds threshold on converting `Or`, `Seq.empty` is returned. 
+ */ + def conjunctiveNormalFormForPartitionPruning(condition: Expression): Seq[Expression] = { + val postOrderNodes = postOrderTraversal(condition) + val resultStack = new mutable.Stack[Seq[Expression]] + val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount + // Bottom up approach to get CNF of sub-expressions + while (postOrderNodes.nonEmpty) { + val cnf = postOrderNodes.pop() match { + case _: And => + val right = resultStack.pop() + val left = resultStack.pop() + left ++ right + case _: Or => + // For each side, there is no need to expand predicates of the same references. + // So here we can aggregate predicates of the same qualifier as one single predicate, + // for reducing the size of pushed down predicates and corresponding codegen. + val right = groupExpressionsByQualifier(resultStack.pop()) + val left = groupExpressionsByQualifier(resultStack.pop()) + // Stop the loop whenever the result exceeds the `maxCnfNodeCount` + if (left.size * right.size > maxCnfNodeCount) { + logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " + + "The CNF conversion is skipped and returning Seq.empty now. To avoid this, you can " + + s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.") + return Seq.empty + } else { + for { x <- left; y <- right } yield Or(x, y) + } + case other => other :: Nil + } + resultStack.push(cnf) + } + if (resultStack.length != 1) { + logWarning("The length of CNF conversion result stack is supposed to be 1. There might " + + "be something wrong with CNF conversion.") + return Seq.empty + } + resultStack.top + } + + private def groupExpressionsByReference(expressions: Seq[Expression]): Seq[Expression] = { + expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq + } + /** * Iterative post order traversal over a binary tree built by And/Or clauses with two stacks. 
* For example, a condition `(a And b) Or c`, the postorder traversal is From 8856453b4089da0f4d64845c3c94cb251413038b Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 09:43:15 +0800 Subject: [PATCH 06/26] Update SQLConf.scala --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3044ceada0cb3..33f40b47d072b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2891,8 +2891,6 @@ class SQLConf extends Serializable with Logging { def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) - def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT) - def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR) def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) From 697a3a93bbe7f06807507fafa80708efaca239d8 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 10:36:06 +0800 Subject: [PATCH 07/26] Update HiveTableScanSuite.scala --- .../apache/spark/sql/hive/execution/HiveTableScanSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index ae2afa817d9c3..44ba851d7d907 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -189,7 +189,7 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH } } - test("Convert scan predicate to CNF") { + test("SPARK-28169: Convert scan predicate condition to CNF") { withTable("t", "temp") { sql( s""" From 7e8319e1c4c40e69fcc964e2306236c28534269a Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 11:01:23 +0800 Subject: [PATCH 08/26] Update predicates.scala --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 89bad7e42ad56..3b68dca807edd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -278,8 +278,8 @@ trait PredicateHelper extends Logging { // For each side, there is no need to expand predicates of the same references. // So here we can aggregate predicates of the same qualifier as one single predicate, // for reducing the size of pushed down predicates and corresponding codegen. - val right = groupExpressionsByQualifier(resultStack.pop()) - val left = groupExpressionsByQualifier(resultStack.pop()) + val right = groupExpressionsByReference(resultStack.pop()) + val left = groupExpressionsByReference(resultStack.pop()) // Stop the loop whenever the result exceeds the `maxCnfNodeCount` if (left.size * right.size > maxCnfNodeCount) { logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. 
" + From 373486615f9d2e44f01bd5eba11c4ee4628df1b1 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 11:08:50 +0800 Subject: [PATCH 09/26] empty safe --- .../execution/PushCNFPredicateThroughFileScan.scala | 10 ++++++++-- .../PushCNFPredicateThroughHiveTableScan.scala | 9 +++++++-- .../spark/sql/hive/execution/HiveTableScanSuite.scala | 9 +++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala index ab9d0769ba23e..76ccd021bd6e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala @@ -34,9 +34,15 @@ object PushCNFPredicateThroughFileScan extends Rule[LogicalPlan] with PredicateH def apply(plan: LogicalPlan): LogicalPlan = { plan transform { - case ScanOperation(projectList, conditions, relation: LogicalRelation) => + case ScanOperation(projectList, conditions, relation: LogicalRelation) + if conditions.nonEmpty => val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) - return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + if (predicates.isEmpty) { + return plan + } else { + return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + } + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala index 48152acda21f5..c7c585f05ad4a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala @@ -33,9 +33,14 @@ object PushCNFPredicateThroughHiveTableScan extends Rule[LogicalPlan] with Predi def apply(plan: LogicalPlan): LogicalPlan = { plan transform { - case ScanOperation(projectList, conditions, relation: HiveTableRelation) => + case ScanOperation(projectList, conditions, relation: HiveTableRelation) + if conditions.nonEmpty => val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) - return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + if (predicates.isEmpty) { + return plan + } else { + return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 44ba851d7d907..6b1f84f64f606 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -214,6 +214,11 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )") val scan4 = getHiveTableScanExec( "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')") + val scan5 = getHiveTableScanExec( + "SELECT * FROM t") + val scan6 = getHiveTableScanExec( + "SELECT * FROM t where p = '1' and i = 2") + assert(scan1.prunedPartitions.map(_.toString) == Stream("t(p=1)", "t(p=2)")) @@ -223,6 +228,10 @@ class 
HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH Stream("t(p=1)", "t(p=3)")) assert(scan4.prunedPartitions.map(_.toString) == Stream("t(p=1)", "t(p=2)", "t(p=3)")) + assert(scan5.prunedPartitions.map(_.toString) == + Stream("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) + assert(scan6.prunedPartitions.map(_.toString) == + Stream("t(p=1)")) } } From b253af3acc3275ea4a3b5e2cc1916f303d8bede4 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 11:25:32 +0800 Subject: [PATCH 10/26] save --- .../PushCNFPredicateThroughFileScan.scala | 21 ++++++++----------- ...PushCNFPredicateThroughHiveTableScan.scala | 20 ++++++++---------- .../hive/execution/HiveTableScanSuite.scala | 8 +++++++ 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala index 76ccd021bd6e0..bc045ab33f3ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala @@ -32,17 +32,14 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation */ object PushCNFPredicateThroughFileScan extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = { - plan transform { - case ScanOperation(projectList, conditions, relation: LogicalRelation) - if conditions.nonEmpty => - val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) - if (predicates.isEmpty) { - return plan - } else { - return Project(projectList, Filter(predicates.reduceLeft(And), relation)) - } - - } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ScanOperation(projectList, conditions, relation: LogicalRelation) + if conditions.nonEmpty => + val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) + if (predicates.isEmpty) { + return plan + } else { + return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala index c7c585f05ad4a..ddf9db5181c56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala @@ -31,16 +31,14 @@ import org.apache.spark.sql.catalyst.rules.Rule */ object PushCNFPredicateThroughHiveTableScan extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = { - plan transform { - case ScanOperation(projectList, conditions, relation: HiveTableRelation) - if conditions.nonEmpty => - val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) - if (predicates.isEmpty) { - return plan - } else { - return Project(projectList, Filter(predicates.reduceLeft(And), relation)) - } - } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ScanOperation(projectList, conditions, relation: HiveTableRelation) + if conditions.nonEmpty => + val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) + if (predicates.isEmpty) { + return plan + } else { + return Project(projectList, 
Filter(predicates.reduceLeft(And), relation)) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 6b1f84f64f606..053e7591f521f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -219,6 +219,12 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH val scan6 = getHiveTableScanExec( "SELECT * FROM t where p = '1' and i = 2") + val scan7 = getHiveTableScanExec( + """ + |SELECT i, COUNT(1) FROM ( + |SELECT * FROM t where p = '1' OR (p = '2' AND i = 1) + |) TMP GROUP BY i + """.stripMargin) assert(scan1.prunedPartitions.map(_.toString) == Stream("t(p=1)", "t(p=2)")) @@ -232,6 +238,8 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH Stream("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) assert(scan6.prunedPartitions.map(_.toString) == Stream("t(p=1)")) + assert(scan7.prunedPartitions.map(_.toString) == + Stream("t(p=1)", "t(p=2)")) } } From 478a7a80266d8a80628de9e111993b1855cb0d2d Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 12:57:44 +0800 Subject: [PATCH 11/26] fix bug --- .../sql/execution/PushCNFPredicateThroughFileScan.scala | 6 +++--- .../execution/PushCNFPredicateThroughHiveTableScan.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala index bc045ab33f3ba..80eec0f5d4df9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala @@ -32,14 +32,14 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation */ object PushCNFPredicateThroughFileScan extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case ScanOperation(projectList, conditions, relation: LogicalRelation) if conditions.nonEmpty => val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) if (predicates.isEmpty) { - return plan + plan } else { - return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + Project(projectList, Filter(predicates.reduceLeft(And), relation)) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala index ddf9db5181c56..0ab4795541509 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala @@ -31,14 +31,14 @@ import org.apache.spark.sql.catalyst.rules.Rule */ object PushCNFPredicateThroughHiveTableScan extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case ScanOperation(projectList, conditions, relation: HiveTableRelation) if conditions.nonEmpty => val predicates = 
conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) if (predicates.isEmpty) { - return plan + plan } else { - return Project(projectList, Filter(predicates.reduceLeft(And), relation)) + Project(projectList, Filter(predicates.reduceLeft(And), relation)) } } } From 603660b724e9d011f26fb4aa8e688f4a8e60f37b Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 15:34:27 +0800 Subject: [PATCH 12/26] save --- .../sql/catalyst/expressions/predicates.scala | 68 ++++--------------- .../PushCNFPredicateThroughJoin.scala | 2 +- .../ConjunctiveNormalFormPredicateSuite.scala | 10 +-- .../PushCNFPredicateThroughFileScan.scala | 4 +- ...PushCNFPredicateThroughHiveTableScan.scala | 2 +- 5 files changed, 22 insertions(+), 64 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3b68dca807edd..c7b91280994ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -211,7 +211,9 @@ trait PredicateHelper extends Logging { * @return the CNF result as sequence of disjunctive expressions. If the number of expressions * exceeds threshold on converting `Or`, `Seq.empty` is returned. */ - def conjunctiveNormalForm(condition: Expression): Seq[Expression] = { + def conjunctiveNormalForm( + condition: Expression, + groupExpsFunc: Seq[Expression] => Seq[Expression] = _.toSeq): Seq[Expression] = { val postOrderNodes = postOrderTraversal(condition) val resultStack = new mutable.Stack[Seq[Expression]] val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount @@ -226,8 +228,8 @@ trait PredicateHelper extends Logging { // For each side, there is no need to expand predicates of the same references. // So here we can aggregate predicates of the same qualifier as one single predicate, // for reducing the size of pushed down predicates and corresponding codegen. - val right = groupExpressionsByQualifier(resultStack.pop()) - val left = groupExpressionsByQualifier(resultStack.pop()) + val right = groupExpsFunc(resultStack.pop()) + val left = groupExpsFunc(resultStack.pop()) // Stop the loop whenever the result exceeds the `maxCnfNodeCount` if (left.size * right.size > maxCnfNodeCount) { logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " + @@ -249,60 +251,16 @@ trait PredicateHelper extends Logging { resultStack.top } - private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = { - expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq - } - - /** - * Convert an expression into conjunctive normal form for partition pruning. - * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form - * CNF can explode exponentially in the size of the input expression when converting [[Or]] - * clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases. - * - * @param condition to be converted into CNF. - * @return the CNF result as sequence of disjunctive expressions. If the number of expressions - * exceeds threshold on converting `Or`, `Seq.empty` is returned. 
- */ - def conjunctiveNormalFormForPartitionPruning(condition: Expression): Seq[Expression] = { - val postOrderNodes = postOrderTraversal(condition) - val resultStack = new mutable.Stack[Seq[Expression]] - val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount - // Bottom up approach to get CNF of sub-expressions - while (postOrderNodes.nonEmpty) { - val cnf = postOrderNodes.pop() match { - case _: And => - val right = resultStack.pop() - val left = resultStack.pop() - left ++ right - case _: Or => - // For each side, there is no need to expand predicates of the same references. - // So here we can aggregate predicates of the same qualifier as one single predicate, - // for reducing the size of pushed down predicates and corresponding codegen. - val right = groupExpressionsByReference(resultStack.pop()) - val left = groupExpressionsByReference(resultStack.pop()) - // Stop the loop whenever the result exceeds the `maxCnfNodeCount` - if (left.size * right.size > maxCnfNodeCount) { - logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " + - "The CNF conversion is skipped and returning Seq.empty now. To avoid this, you can " + - s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.") - return Seq.empty - } else { - for { x <- left; y <- right } yield Or(x, y) - } - case other => other :: Nil - } - resultStack.push(cnf) - } - if (resultStack.length != 1) { - logWarning("The length of CNF conversion result stack is supposed to be 1. There might " + - "be something wrong with CNF conversion.") - return Seq.empty - } - resultStack.top + def conjunctiveNormalFormAndGroupExpsByQualifier(condition: Expression): Seq[Expression] = { + conjunctiveNormalForm(condition, + (expressions: Seq[Expression]) => + expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq) } - private def groupExpressionsByReference(expressions: Seq[Expression]): Seq[Expression] = { - expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq + def conjunctiveNormalFormAndGroupExpsByReference(condition: Expression): Seq[Expression] = { + conjunctiveNormalForm(condition, + (expressions: Seq[Expression]) => + expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala index f406b7d77ab63..9764bc3b4d216 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.rules.Rule object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case j @ Join(left, right, joinType, Some(joinCondition), hint) => - val predicates = conjunctiveNormalForm(joinCondition) + val predicates = conjunctiveNormalFormAndGroupExpsByQualifier(joinCondition) if (predicates.isEmpty) { j } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala index b449ed5cc0d07..fe8eddc19da3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala @@ -43,7 +43,7 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe // Check CNF conversion with expected expression, assuming the input has non-empty result. private def checkCondition(input: Expression, expected: Expression): Unit = { - val cnf = conjunctiveNormalForm(input) + val cnf = conjunctiveNormalFormAndGroupExpsByQualifier(input) assert(cnf.nonEmpty) val result = cnf.reduceLeft(And) assert(result.semanticEquals(expected)) @@ -113,14 +113,14 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe Seq(8, 9, 10, 35, 36, 37).foreach { maxCount => withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) { if (maxCount < 36) { - assert(conjunctiveNormalForm(input).isEmpty) + assert(conjunctiveNormalFormAndGroupExpsByQualifier(input).isEmpty) } else { - assert(conjunctiveNormalForm(input).nonEmpty) + assert(conjunctiveNormalFormAndGroupExpsByQualifier(input).nonEmpty) } if (maxCount < 9) { - assert(conjunctiveNormalForm(input2).isEmpty) + assert(conjunctiveNormalFormAndGroupExpsByQualifier(input2).isEmpty) } else { - assert(conjunctiveNormalForm(input2).nonEmpty) + assert(conjunctiveNormalFormAndGroupExpsByQualifier(input2).nonEmpty) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala index 80eec0f5d4df9..a6cfdb2e2736b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala @@ -32,10 +32,10 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation */ object PushCNFPredicateThroughFileScan extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case ScanOperation(projectList, conditions, relation: LogicalRelation) if conditions.nonEmpty => - val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) + val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) if (predicates.isEmpty) { plan } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala index 0ab4795541509..4e8016888eca1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala @@ -34,7 +34,7 @@ object PushCNFPredicateThroughHiveTableScan extends Rule[LogicalPlan] with Predi def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case ScanOperation(projectList, conditions, relation: HiveTableRelation) if conditions.nonEmpty => - val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And)) + val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) if (predicates.isEmpty) { plan } else { From 69f176372a21e2b10fddec24e3ee05650bcc423e Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 17:02:39 +0800 Subject: [PATCH 13/26] wip --- 
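A rough, standalone sketch of the idea behind the new groupExpsFunc parameter introduced in the previous patch, written against a toy expression type rather than Catalyst's classes (all names below are illustrative only, not Spark code):

sealed trait Expr
case class Leaf(name: String, table: String) extends Expr   // stands in for a predicate such as t1.a > 0
case class And(l: Expr, r: Expr) extends Expr
case class Or(l: Expr, r: Expr) extends Expr

// Tables referenced by an expression (the analogue of Expression.references).
def references(e: Expr): Set[String] = e match {
  case Leaf(_, t) => Set(t)
  case And(l, r)  => references(l) ++ references(r)
  case Or(l, r)   => references(l) ++ references(r)
}

// Recursive CNF conversion: before distributing Or over And, conjuncts that
// touch the same tables are collapsed into a single conjunct, which is what
// keeps the cross product (and the pushed-down predicate) small.
def toCnf(e: Expr): Seq[Expr] = e match {
  case And(l, r) => toCnf(l) ++ toCnf(r)
  case Or(l, r) =>
    def grouped(cs: Seq[Expr]): Seq[Expr] =
      cs.groupBy(references).values.map(_.reduceLeft(And)).toSeq
    for { x <- grouped(toCnf(l)); y <- grouped(toCnf(r)) } yield Or(x, y)
  case leaf => Seq(leaf)
}

// (t1.a AND t2.b) OR t1.c  ->  (t1.a OR t1.c) AND (t2.b OR t1.c)
toCnf(Or(And(Leaf("a", "t1"), Leaf("b", "t2")), Leaf("c", "t1")))
// (t1.a AND t1.b) OR t2.c  ->  ((t1.a AND t1.b) OR t2.c), a single conjunct,
// because the two t1 predicates are grouped instead of being expanded.
toCnf(Or(And(Leaf("a", "t1"), Leaf("b", "t1")), Leaf("c", "t2")))

The patched Catalyst version does the same work iteratively over a post-order traversal of the And/Or tree and returns Seq.empty once the product of the two sides exceeds SQLConf.MAX_CNF_NODE_COUNT.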
.../sql/catalyst/optimizer/Optimizer.scala | 8 ++++++- .../PushCNFPredicateThroughFileScan.scala | 22 +++++++++++-------- .../spark/sql/execution/SparkOptimizer.scala | 6 +++-- .../internal/BaseSessionStateBuilder.scala | 11 ++++++++++ .../sql/hive/HiveSessionStateBuilder.scala | 5 ++++- ...PushCNFPredicateThroughHiveTableScan.scala | 22 +++++++++++-------- 6 files changed, 52 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a1a7213664ac8..165347f873a0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -207,7 +207,8 @@ abstract class Optimizer(catalogManager: CatalogManager) CollapseProject, RemoveNoopOperators) :+ // This batch must be executed after the `RewriteSubquery` batch, which creates joins. - Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) + Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ + Batch("Final Filter Convert CNF", Once, finalScanFilterConvertRules: _*) // remove any batches with no rules. this may happen when subclasses do not add optional rules. batches.filter(_.rules.nonEmpty) @@ -273,6 +274,11 @@ abstract class Optimizer(catalogManager: CatalogManager) */ def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = Nil + /** + * Override to provide additional rules for final filter convert to CNF. + */ + def finalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = Nil + /** * Returns (defaultBatches - (excludedRules - nonExcludableRules)), the rule batches that * eventually run in the Optimizer. 
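The finalScanFilterConvertRules hook added above is filled in by the session state builders later in this patch. A minimal, hypothetical sketch of an Optimizer subclass using the hook (NoopScanFilterRule and MyOptimizer are invented placeholders, not part of the change):

import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager

// Placeholder rule: a real implementation would rewrite scan filters to CNF,
// as PushCNFPredicateThroughFileScan does below.
object NoopScanFilterRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

// Hypothetical optimizer that contributes a rule to the new
// "Final Filter Convert CNF" batch via the hook added above.
class MyOptimizer(catalogManager: CatalogManager) extends Optimizer(catalogManager) {
  override def finalScanFilterConvertRules: Seq[Rule[LogicalPlan]] =
    NoopScanFilterRule :: Nil
}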
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala index a6cfdb2e2736b..81176bbea9eb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala @@ -32,14 +32,18 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation */ object PushCNFPredicateThroughFileScan extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case ScanOperation(projectList, conditions, relation: LogicalRelation) - if conditions.nonEmpty => - val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) - if (predicates.isEmpty) { - plan - } else { - Project(projectList, Filter(predicates.reduceLeft(And), relation)) - } + def apply(plan: LogicalPlan): LogicalPlan = { + var resolved = false + plan resolveOperatorsDown { + case ScanOperation(projectList, conditions, relation: LogicalRelation) + if conditions.nonEmpty && !resolved => + resolved = true + val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) + if (predicates.isEmpty) { + plan + } else { + Project(projectList, Filter(predicates.reduceLeft(And), relation)) + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 4c8ae8d48437a..fe7e12a982e5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -37,8 +37,10 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst - SchemaPruning :: V2ScanRelationPushDown :: PushCNFPredicateThroughFileScan :: - PruneFileSourcePartitions :: Nil + SchemaPruning :: V2ScanRelationPushDown :: PruneFileSourcePartitions :: Nil + + override def finalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = + PushCNFPredicateThroughFileScan :: Nil override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 3bbdbb002cca8..088cb3081982a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -235,6 +235,9 @@ abstract class BaseSessionStateBuilder( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = super.earlyScanPushDownRules ++ customEarlyScanPushDownRules + override def finalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = + super.finalScanFilterConvertRules ++ customFinalScanFilterConvertRules + override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules } @@ -258,6 +261,14 @@ abstract class BaseSessionStateBuilder( */ protected def customEarlyScanPushDownRules: Seq[Rule[LogicalPlan]] = Nil + /** + * Custom final scan filter convert rules to add to the Optimizer. 
Prefer overriding this instead + * of creating your own Optimizer. + * + * Note that this may NOT depend on the `optimizer` function. + */ + protected def customFinalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = Nil + /** * Planner that converts optimized logical plans to physical plans. * diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 0c7ad54dc94b7..043a3a32d4202 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -99,7 +99,10 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session } override def customEarlyScanPushDownRules: Seq[Rule[LogicalPlan]] = - Seq(PushCNFPredicateThroughHiveTableScan, new PruneHiveTablePartitions(session)) + Seq(new PruneHiveTablePartitions(session)) + + override def customFinalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = + Seq(PushCNFPredicateThroughHiveTableScan) /** * Planner that takes into account Hive-specific strategies. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala index 4e8016888eca1..091c90aae0a5b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala @@ -31,14 +31,18 @@ import org.apache.spark.sql.catalyst.rules.Rule */ object PushCNFPredicateThroughHiveTableScan extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case ScanOperation(projectList, conditions, relation: HiveTableRelation) - if conditions.nonEmpty => - val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) - if (predicates.isEmpty) { - plan - } else { - Project(projectList, Filter(predicates.reduceLeft(And), relation)) - } + def apply(plan: LogicalPlan): LogicalPlan = { + var resolved = false + plan resolveOperatorsDown { + case ScanOperation(projectList, conditions, relation: HiveTableRelation) + if conditions.nonEmpty && !resolved => + resolved = true + val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) + if (predicates.isEmpty) { + plan + } else { + Project(projectList, Filter(predicates.reduceLeft(And), relation)) + } + } } } From 2f576fab532c903e067efffff271bb361421a640 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 12 Jun 2020 21:01:26 +0800 Subject: [PATCH 14/26] fix return bug --- .../spark/sql/execution/PushCNFPredicateThroughFileScan.scala | 4 ++-- .../hive/execution/PushCNFPredicateThroughHiveTableScan.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala index 81176bbea9eb1..dd050b62338df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala @@ -35,12 +35,12 @@ object PushCNFPredicateThroughFileScan extends Rule[LogicalPlan] with PredicateH def 
apply(plan: LogicalPlan): LogicalPlan = { var resolved = false plan resolveOperatorsDown { - case ScanOperation(projectList, conditions, relation: LogicalRelation) + case op @ ScanOperation(projectList, conditions, relation: LogicalRelation) if conditions.nonEmpty && !resolved => resolved = true val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) if (predicates.isEmpty) { - plan + op } else { Project(projectList, Filter(predicates.reduceLeft(And), relation)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala index 091c90aae0a5b..61fde45167039 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala @@ -34,12 +34,12 @@ object PushCNFPredicateThroughHiveTableScan extends Rule[LogicalPlan] with Predi def apply(plan: LogicalPlan): LogicalPlan = { var resolved = false plan resolveOperatorsDown { - case ScanOperation(projectList, conditions, relation: HiveTableRelation) + case op @ ScanOperation(projectList, conditions, relation: HiveTableRelation) if conditions.nonEmpty && !resolved => resolved = true val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) if (predicates.isEmpty) { - plan + op } else { Project(projectList, Filter(predicates.reduceLeft(And), relation)) } From 94609c81d2b47f67624ae0184a1824a3348297da Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 22 Jun 2020 18:03:33 +0800 Subject: [PATCH 15/26] =?UTF-8?q?Don=E2=80=98t=20add=20new=20rule?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sql/catalyst/optimizer/Optimizer.scala | 8 +-- .../PushCNFPredicateThroughFileScan.scala | 49 ------------------- .../spark/sql/execution/SparkOptimizer.scala | 3 -- .../PruneFileSourcePartitions.scala | 17 +++++-- .../internal/BaseSessionStateBuilder.scala | 11 ----- .../sql/hive/HiveSessionStateBuilder.scala | 5 +- .../execution/PruneHiveTablePartitions.scala | 13 +++-- ...PushCNFPredicateThroughHiveTableScan.scala | 48 ------------------ 8 files changed, 25 insertions(+), 129 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2bf95813b6372..e800ee3b93f51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -207,8 +207,7 @@ abstract class Optimizer(catalogManager: CatalogManager) CollapseProject, RemoveNoopOperators) :+ // This batch must be executed after the `RewriteSubquery` batch, which creates joins. - Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ - Batch("Final Filter Convert CNF", Once, finalScanFilterConvertRules: _*) + Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) // remove any batches with no rules. this may happen when subclasses do not add optional rules. 
batches.filter(_.rules.nonEmpty) @@ -274,11 +273,6 @@ abstract class Optimizer(catalogManager: CatalogManager) */ def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = Nil - /** - * Override to provide additional rules for final filter convert to CNF. - */ - def finalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = Nil - /** * Returns (defaultBatches - (excludedRules - nonExcludableRules)), the rule batches that * eventually run in the Optimizer. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala deleted file mode 100644 index dd050b62338df..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/PushCNFPredicateThroughFileScan.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper} -import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.LogicalRelation - -/** - * Try converting join condition to conjunctive normal form expression so that more predicates may - * be able to be pushed down. - * To avoid expanding the join condition, the join condition will be kept in the original form even - * when predicate pushdown happens. 
- */ -object PushCNFPredicateThroughFileScan extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = { - var resolved = false - plan resolveOperatorsDown { - case op @ ScanOperation(projectList, conditions, relation: LogicalRelation) - if conditions.nonEmpty && !resolved => - resolved = true - val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) - if (predicates.isEmpty) { - op - } else { - Project(projectList, Filter(predicates.reduceLeft(And), relation)) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index fe7e12a982e5a..33b86a2b5340c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -39,9 +39,6 @@ class SparkOptimizer( // TODO: move SchemaPruning into catalyst SchemaPruning :: V2ScanRelationPushDown :: PruneFileSourcePartitions :: Nil - override def finalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = - PushCNFPredicateThroughFileScan :: Nil - override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("PartitionPruning", Once, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index a7129fb14d1a6..8a9dd9845940b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -39,7 +39,8 @@ import org.apache.spark.sql.types.StructType * its underlying [[FileScan]]. And the partition filters will be removed in the filters of * returned logical plan. 
*/ -private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { +private[sql] object PruneFileSourcePartitions + extends Rule[LogicalPlan] with PredicateHelper { private def getPartitionKeyFiltersAndDataFilters( sparkSession: SparkSession, @@ -87,8 +88,17 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { _, _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => - val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( - fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output) + val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And)) + val (partitionKeyFilters, _) = if (predicates.nonEmpty) { + getPartitionKeyFiltersAndDataFilters( + fsRelation.sparkSession, logicalRelation, partitionSchema, predicates, + logicalRelation.output) + } else { + getPartitionKeyFiltersAndDataFilters( + fsRelation.sparkSession, logicalRelation, partitionSchema, filters, + logicalRelation.output) + } + if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = @@ -104,6 +114,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { op } + case op @ PhysicalOperation(projects, filters, v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) if filters.nonEmpty && scan.readDataSchema.nonEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 088cb3081982a..3bbdbb002cca8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -235,9 +235,6 @@ abstract class BaseSessionStateBuilder( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = super.earlyScanPushDownRules ++ customEarlyScanPushDownRules - override def finalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = - super.finalScanFilterConvertRules ++ customFinalScanFilterConvertRules - override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules } @@ -261,14 +258,6 @@ abstract class BaseSessionStateBuilder( */ protected def customEarlyScanPushDownRules: Seq[Rule[LogicalPlan]] = Nil - /** - * Custom final scan filter convert rules to add to the Optimizer. Prefer overriding this instead - * of creating your own Optimizer. - * - * Note that this may NOT depend on the `optimizer` function. - */ - protected def customFinalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = Nil - /** * Planner that converts optimized logical plans to physical plans. 
* diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 043a3a32d4202..64726755237a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.command.CommandCheck import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.TableCapabilityCheck import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.hive.execution.{PruneHiveTablePartitions, PushCNFPredicateThroughHiveTableScan} +import org.apache.spark.sql.hive.execution.PruneHiveTablePartitions import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** @@ -101,9 +101,6 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override def customEarlyScanPushDownRules: Seq[Rule[LogicalPlan]] = Seq(new PruneHiveTablePartitions(session)) - override def customFinalScanFilterConvertRules: Seq[Rule[LogicalPlan]] = - Seq(PushCNFPredicateThroughHiveTableScan) - /** * Planner that takes into account Hive-specific strategies. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala index da6e4c52cf3a7..bf0fb2a686200 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala @@ -21,8 +21,8 @@ import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.CastSupport -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, ExternalCatalogUtils, HiveTableRelation} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, SubqueryExpression} +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -41,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf * TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source. 
*/ private[sql] class PruneHiveTablePartitions(session: SparkSession) - extends Rule[LogicalPlan] with CastSupport { + extends Rule[LogicalPlan] with CastSupport with PredicateHelper { override val conf: SQLConf = session.sessionState.conf @@ -103,7 +103,12 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation) if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty => - val partitionKeyFilters = getPartitionKeyFilters(filters, relation) + val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And)) + val partitionKeyFilters = if (predicates.nonEmpty) { + getPartitionKeyFilters(predicates, relation) + } else { + getPartitionKeyFilters(filters, relation) + } if (partitionKeyFilters.nonEmpty) { val newPartitions = prunePartitions(relation, partitionKeyFilters) val newTableMeta = updateTableMeta(relation.tableMeta, newPartitions) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala deleted file mode 100644 index 61fde45167039..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PushCNFPredicateThroughHiveTableScan.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.catalyst.catalog.HiveTableRelation -import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper} -import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * Try converting join condition to conjunctive normal form expression so that more predicates may - * be able to be pushed down. - * To avoid expanding the join condition, the join condition will be kept in the original form even - * when predicate pushdown happens. 
- */ -object PushCNFPredicateThroughHiveTableScan extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = { - var resolved = false - plan resolveOperatorsDown { - case op @ ScanOperation(projectList, conditions, relation: HiveTableRelation) - if conditions.nonEmpty && !resolved => - resolved = true - val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And)) - if (predicates.isEmpty) { - op - } else { - Project(projectList, Filter(predicates.reduceLeft(And), relation)) - } - } - } -} From 326fb49a6e19ff3363c02736cf3eeba1d9fc1928 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 22 Jun 2020 18:05:12 +0800 Subject: [PATCH 16/26] Update PruneFileSourcePartitions.scala --- .../sql/execution/datasources/PruneFileSourcePartitions.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 8a9dd9845940b..2568121fafbfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -114,7 +114,6 @@ private[sql] object PruneFileSourcePartitions op } - case op @ PhysicalOperation(projects, filters, v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) if filters.nonEmpty && scan.readDataSchema.nonEmpty => From 9322ae62e20715375baf85870e320ed66cf60a06 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 24 Jun 2020 16:34:59 +0800 Subject: [PATCH 17/26] follow comment --- .../sql/catalyst/expressions/predicates.scala | 2 +- .../PruneFileSourcePartitions.scala | 13 ++--- .../execution/PruneHiveTablePartitions.scala | 7 +-- .../hive/execution/HiveTableScanSuite.scala | 56 +++++++++---------- .../PruneFileSourcePartitionsSuite.scala | 51 +++++++++++++++++ 5 files changed, 84 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c67a262221043..719f99dfdac4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -260,7 +260,7 @@ trait PredicateHelper extends Logging { def conjunctiveNormalFormAndGroupExpsByReference(condition: Expression): Seq[Expression] = { conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => - expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq) + expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 2568121fafbfd..3fa84aedaa78a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -89,15 +89,10 @@ private[sql] object PruneFileSourcePartitions _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And)) - 
val (partitionKeyFilters, _) = if (predicates.nonEmpty) { - getPartitionKeyFiltersAndDataFilters( - fsRelation.sparkSession, logicalRelation, partitionSchema, predicates, - logicalRelation.output) - } else { - getPartitionKeyFiltersAndDataFilters( - fsRelation.sparkSession, logicalRelation, partitionSchema, filters, - logicalRelation.output) - } + val finalPredicates = if (predicates.nonEmpty) predicates else filters + val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( + fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates, + logicalRelation.output) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala index bf0fb2a686200..1086133822a56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala @@ -104,11 +104,8 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession) case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation) if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty => val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And)) - val partitionKeyFilters = if (predicates.nonEmpty) { - getPartitionKeyFilters(predicates, relation) - } else { - getPartitionKeyFilters(filters, relation) - } + val finalPredicates = if (predicates.nonEmpty) predicates else filters + val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation) if (partitionKeyFilters.nonEmpty) { val newPartitions = prunePartitions(relation, partitionKeyFilters) val newTableMeta = updateTableMeta(relation.tableMeta, newPartitions) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 053e7591f521f..c7c2a2775c31d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -206,43 +206,39 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH |select col from temp""".stripMargin) } - val scan1 = getHiveTableScanExec( - "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)") - val scan2 = getHiveTableScanExec( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')") - val scan3 = getHiveTableScanExec( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )") - val scan4 = getHiveTableScanExec( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')") - val scan5 = getHiveTableScanExec( - "SELECT * FROM t") - val scan6 = getHiveTableScanExec( - "SELECT * FROM t where p = '1' and i = 2") - - val scan7 = getHiveTableScanExec( + assertPrunedPartitions( + "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", + Array("t(p=1)", "t(p=2)")) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')", + Array("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )", + Array("t(p=1)", "t(p=3)")) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = 
'1' and i = 2) or (p = '2' or p = '3')", + Array("t(p=1)", "t(p=2)", "t(p=3)")) + assertPrunedPartitions( + "SELECT * FROM t", + Array("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) + assertPrunedPartitions( + "SELECT * FROM t where p = '1' and i = 2", + Array("t(p=1)")) + assertPrunedPartitions( """ |SELECT i, COUNT(1) FROM ( |SELECT * FROM t where p = '1' OR (p = '2' AND i = 1) |) TMP GROUP BY i - """.stripMargin) - - assert(scan1.prunedPartitions.map(_.toString) == - Stream("t(p=1)", "t(p=2)")) - assert(scan2.prunedPartitions.map(_.toString) == - Stream("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) - assert(scan3.prunedPartitions.map(_.toString) == - Stream("t(p=1)", "t(p=3)")) - assert(scan4.prunedPartitions.map(_.toString) == - Stream("t(p=1)", "t(p=2)", "t(p=3)")) - assert(scan5.prunedPartitions.map(_.toString) == - Stream("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) - assert(scan6.prunedPartitions.map(_.toString) == - Stream("t(p=1)")) - assert(scan7.prunedPartitions.map(_.toString) == - Stream("t(p=1)", "t(p=2)")) + """.stripMargin, + Array("t(p=1)", "t(p=2)")) } } + private def assertPrunedPartitions(query: String, expected: Array[String]): Unit = { + val prunedPartitions = getHiveTableScanExec(query).prunedPartitions.map(_.toString).toArray + assert(prunedPartitions.sameElements(expected)) + } + private def getHiveTableScanExec(query: String): HiveTableScanExec = { sql(query).queryExecution.sparkPlan.collectFirst { case p: HiveTableScanExec => p diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index c9c36992906a8..6097c66fb887a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.functions.broadcast import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -108,4 +109,54 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te } } } + + test("SPARK-28169: Convert scan predicate condition to CNF") { + withTable("t", "temp") { + sql( + s""" + |CREATE TABLE t(i int, p string) + |USING PARQUET + |PARTITIONED BY (p) + |""".stripMargin) + spark.range(0, 1000, 1).selectExpr("id as col") + .createOrReplaceTempView("temp") + + for (part <- Seq(1, 2, 3, 4)) { + sql( + s""" + |INSERT OVERWRITE TABLE t PARTITION (p='$part') + |select col from temp""".stripMargin) + } + + assertPrunedPartitions( + "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')", 4) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )", 2) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')", 3) + assertPrunedPartitions( + "SELECT * FROM t", 4) + assertPrunedPartitions( + "SELECT * FROM t where p = '1' and i = 2", 1) + assertPrunedPartitions( + 
""" + |SELECT i, COUNT(1) FROM ( + |SELECT * FROM t where p = '1' OR (p = '2' AND i = 1) + |) TMP GROUP BY i + """.stripMargin, 2) + } + } + + private def assertPrunedPartitions(query: String, expected: Long): Unit = { + val prunedPartitions = getFileScanExec(query).relation.location.inputFiles.length + assert(prunedPartitions == expected) + } + + private def getFileScanExec(query: String): FileSourceScanExec = { + sql(query).queryExecution.sparkPlan.collectFirst { + case p: FileSourceScanExec => p + }.get + } } From 4a2adcd60c47fc8e3577801f08aee6d7aa44e6b0 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 24 Jun 2020 17:12:17 +0800 Subject: [PATCH 18/26] Update PruneFileSourcePartitionsSuite.scala --- .../sql/hive/execution/PruneFileSourcePartitionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index 6097c66fb887a..0e0ea2f0308e1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -25,10 +25,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec -import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.functions.broadcast import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf From 0e2579d5836c7d1f2a22b3dae9711d09ddbbfe73 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 26 Jun 2020 12:42:41 +0800 Subject: [PATCH 19/26] move test case --- .../hive/execution/HiveTableScanSuite.scala | 50 ---------------- .../PruneHiveTablePartitionsSuite.scala | 58 ++++++++++++++++++- 2 files changed, 57 insertions(+), 51 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index c7c2a2775c31d..6c3dbcbd50423 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -189,56 +189,6 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH } } - test("SPARK-28169: Convert scan predicate condition to CNF") { - withTable("t", "temp") { - sql( - s""" - |CREATE TABLE t(i int) - |PARTITIONED BY (p int) - |STORED AS textfile""".stripMargin) - spark.range(0, 1000, 1).selectExpr("id as col") - .createOrReplaceTempView("temp") - - for (part <- Seq(1, 2, 3, 4)) { - sql( - s""" - |INSERT OVERWRITE TABLE t PARTITION (p='$part') - |select col from temp""".stripMargin) - } - - assertPrunedPartitions( - "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", - Array("t(p=1)", "t(p=2)")) - assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 
2) or (i = 1 or p = '2')", - Array("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) - assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )", - Array("t(p=1)", "t(p=3)")) - assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')", - Array("t(p=1)", "t(p=2)", "t(p=3)")) - assertPrunedPartitions( - "SELECT * FROM t", - Array("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) - assertPrunedPartitions( - "SELECT * FROM t where p = '1' and i = 2", - Array("t(p=1)")) - assertPrunedPartitions( - """ - |SELECT i, COUNT(1) FROM ( - |SELECT * FROM t where p = '1' OR (p = '2' AND i = 1) - |) TMP GROUP BY i - """.stripMargin, - Array("t(p=1)", "t(p=2)")) - } - } - - private def assertPrunedPartitions(query: String, expected: Array[String]): Unit = { - val prunedPartitions = getHiveTableScanExec(query).prunedPartitions.map(_.toString).toArray - assert(prunedPartitions.sameElements(expected)) - } - private def getHiveTableScanExec(query: String): HiveTableScanExec = { sql(query).queryExecution.sparkPlan.collectFirst { case p: HiveTableScanExec => p diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala index e41709841a736..55ce55e2c86e7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala @@ -32,7 +32,7 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes EliminateSubqueryAliases, new PruneHiveTablePartitions(spark)) :: Nil } - test("SPARK-15616 statistics pruned after going throuhg PruneHiveTablePartitions") { + test("SPARK-15616 statistics pruned after going through PruneHiveTablePartitions") { withTable("test", "temp") { sql( s""" @@ -54,4 +54,60 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes Optimize.execute(analyzed2).stats.sizeInBytes) } } + + test("SPARK-28169: Convert scan predicate condition to CNF") { + withTable("t", "temp") { + sql( + s""" + |CREATE TABLE t(i int) + |PARTITIONED BY (p int) + |STORED AS textfile""".stripMargin) + spark.range(0, 1000, 1).selectExpr("id as col") + .createOrReplaceTempView("temp") + + for (part <- Seq(1, 2, 3, 4)) { + sql( + s""" + |INSERT OVERWRITE TABLE t PARTITION (p='$part') + |select col from temp""".stripMargin) + } + + assertPrunedPartitions( + "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", + Array("t(p=1)", "t(p=2)")) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')", + Array("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )", + Array("t(p=1)", "t(p=3)")) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')", + Array("t(p=1)", "t(p=2)", "t(p=3)")) + assertPrunedPartitions( + "SELECT * FROM t", + Array("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) + assertPrunedPartitions( + "SELECT * FROM t where p = '1' and i = 2", + Array("t(p=1)")) + assertPrunedPartitions( + """ + |SELECT i, COUNT(1) FROM ( + |SELECT * FROM t where p = '1' OR (p = '2' AND i = 1) + |) TMP GROUP BY i + """.stripMargin, + Array("t(p=1)", "t(p=2)")) + } + } + + private def assertPrunedPartitions(query: String, expected: Array[String]): Unit = { + val prunedPartitions = 
getHiveTableScanExec(query).prunedPartitions.map(_.toString).toArray + assert(prunedPartitions.sameElements(expected)) + } + + private def getHiveTableScanExec(query: String): HiveTableScanExec = { + sql(query).queryExecution.sparkPlan.collectFirst { + case p: HiveTableScanExec => p + }.get + } } From 270324ee306f035352b58e77718d73810f1ffa1f Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 26 Jun 2020 12:46:29 +0800 Subject: [PATCH 20/26] Update HiveTableScanSuite.scala --- .../apache/spark/sql/hive/execution/HiveTableScanSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 6c3dbcbd50423..67d7ed0841abb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.And -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ From 219f2005ec2f276074188766f63e1b4b6f8ce30f Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 29 Jun 2020 14:47:01 +0800 Subject: [PATCH 21/26] follow comment --- .../PruneFileSourcePartitionsSuite.scala | 55 ++------------ .../PruneHiveTablePartitionsSuite.scala | 61 ++------------- .../execution/PrunePartitionSuiteBase.scala | 75 +++++++++++++++++++ 3 files changed, 85 insertions(+), 106 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index 0e0ea2f0308e1..8f7e1b9719ff5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import org.scalatest.Matchers._ -import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -30,12 +29,12 @@ import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRel import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions.broadcast -import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType -class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { +class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase { + + convert = "true" object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil @@ -110,53 +109,9 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils 
with Te } } - test("SPARK-28169: Convert scan predicate condition to CNF") { - withTable("t", "temp") { - sql( - s""" - |CREATE TABLE t(i int, p string) - |USING PARQUET - |PARTITIONED BY (p) - |""".stripMargin) - spark.range(0, 1000, 1).selectExpr("id as col") - .createOrReplaceTempView("temp") - - for (part <- Seq(1, 2, 3, 4)) { - sql( - s""" - |INSERT OVERWRITE TABLE t PARTITION (p='$part') - |select col from temp""".stripMargin) - } - - assertPrunedPartitions( - "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2) - assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')", 4) - assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )", 2) - assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')", 3) - assertPrunedPartitions( - "SELECT * FROM t", 4) - assertPrunedPartitions( - "SELECT * FROM t where p = '1' and i = 2", 1) - assertPrunedPartitions( - """ - |SELECT i, COUNT(1) FROM ( - |SELECT * FROM t where p = '1' OR (p = '2' AND i = 1) - |) TMP GROUP BY i - """.stripMargin, 2) - } - } - - private def assertPrunedPartitions(query: String, expected: Long): Unit = { - val prunedPartitions = getFileScanExec(query).relation.location.inputFiles.length - assert(prunedPartitions == expected) - } - - private def getFileScanExec(query: String): FileSourceScanExec = { + override def getScanExecPartitionSize(query: String): Long = { sql(query).queryExecution.sparkPlan.collectFirst { case p: FileSourceScanExec => p - }.get + }.get.relation.location.inputFiles.length } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala index 55ce55e2c86e7..e560ebc7842b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils -class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { +class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase { + + convert = "false" object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -55,59 +54,9 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes } } - test("SPARK-28169: Convert scan predicate condition to CNF") { - withTable("t", "temp") { - sql( - s""" - |CREATE TABLE t(i int) - |PARTITIONED BY (p int) - |STORED AS textfile""".stripMargin) - spark.range(0, 1000, 1).selectExpr("id as col") - .createOrReplaceTempView("temp") - - for (part <- Seq(1, 2, 3, 4)) { - sql( - s""" - |INSERT OVERWRITE TABLE t PARTITION (p='$part') - |select col from temp""".stripMargin) - } - - assertPrunedPartitions( - "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", - Array("t(p=1)", "t(p=2)")) - assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')", - Array("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) - assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' 
and i = 2) or (p = '3' and i = 3 )", - Array("t(p=1)", "t(p=3)")) - assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')", - Array("t(p=1)", "t(p=2)", "t(p=3)")) - assertPrunedPartitions( - "SELECT * FROM t", - Array("t(p=1)", "t(p=2)", "t(p=3)", "t(p=4)")) - assertPrunedPartitions( - "SELECT * FROM t where p = '1' and i = 2", - Array("t(p=1)")) - assertPrunedPartitions( - """ - |SELECT i, COUNT(1) FROM ( - |SELECT * FROM t where p = '1' OR (p = '2' AND i = 1) - |) TMP GROUP BY i - """.stripMargin, - Array("t(p=1)", "t(p=2)")) - } - } - - private def assertPrunedPartitions(query: String, expected: Array[String]): Unit = { - val prunedPartitions = getHiveTableScanExec(query).prunedPartitions.map(_.toString).toArray - assert(prunedPartitions.sameElements(expected)) - } - - private def getHiveTableScanExec(query: String): HiveTableScanExec = { + override def getScanExecPartitionSize(query: String): Long = { sql(query).queryExecution.sparkPlan.collectFirst { case p: HiveTableScanExec => p - }.get + }.get.prunedPartitions.size } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala new file mode 100644 index 0000000000000..3f8653ec07519 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with TestHiveSingleton { + + var convert: String = _ + + test("SPARK-28169: Convert scan predicate condition to CNF") { + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> convert, + HiveUtils.CONVERT_METASTORE_ORC.key -> convert) { + withTable("t", "temp") { + sql( + s""" + |CREATE TABLE t(i int) + |PARTITIONED BY (p int) + |STORED AS PARQUET""".stripMargin) + spark.range(0, 1000, 1).selectExpr("id as col") + .createOrReplaceTempView("temp") + + for (part <- Seq(1, 2, 3, 4)) { + sql( + s""" + |INSERT OVERWRITE TABLE t PARTITION (p='$part') + |select col from temp""".stripMargin) + } + + assertPrunedPartitions( + "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')", 4) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )", 2) + assertPrunedPartitions( + "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')", 3) + assertPrunedPartitions( + "SELECT * FROM t", 4) + assertPrunedPartitions( + "SELECT * FROM t where p = '1' and i = 2", 1) + assertPrunedPartitions( + """ + |SELECT i, COUNT(1) FROM ( + |SELECT * FROM t where p = '1' OR (p = '2' AND i = 1) + |) TMP GROUP BY i + """.stripMargin, 2) + } + } + } + + protected def assertPrunedPartitions(query: String, expected: Long): Unit = { + assert(getScanExecPartitionSize(query) == expected) + } + + protected def getScanExecPartitionSize(query: String): Long +} From f21cf43fce759c8589d38a5a0d68213a1f47565b Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 30 Jun 2020 11:09:23 +0800 Subject: [PATCH 22/26] follow comment --- .../sql/catalyst/expressions/predicates.scala | 20 +++++++++++++ .../PruneFileSourcePartitionsSuite.scala | 2 +- .../PruneHiveTablePartitionsSuite.scala | 2 +- .../execution/PrunePartitionSuiteBase.scala | 28 +++++++++---------- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 719f99dfdac4a..2efb801bfc53f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -251,12 +251,32 @@ trait PredicateHelper extends Logging { resultStack.top } + /** + * Convert an expression to conjunctive normal form when pushing predicates through Join, + * when expand predicates, we can group by the qualifier avoiding generate unnecessary + * expression to control the length of final result since there are multiple tables. 
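To make the grouping-by-qualifier idea above concrete, here is a small standalone Scala sketch (illustrative only, using made-up Pred/And/Or case classes rather than Catalyst's Expression tree): when an Or branch contains several predicates on the same table, they are merged into a single conjunct before the Or is distributed, so the CNF result stays small while single-table conjuncts remain pushable below the join.

object CnfGroupingSketch {
  sealed trait Expr
  case class Pred(qualifier: String, text: String) extends Expr // e.g. t1.a = 1
  case class And(left: Expr, right: Expr) extends Expr
  case class Or(left: Expr, right: Expr) extends Expr

  // Merge predicates that touch the same table into one conjunct before expanding Or.
  private def groupByQualifier(es: Seq[Expr]): Seq[Expr] =
    es.groupBy {
      case Pred(q, _) => q
      case _ => ""
    }.values.map(_.reduceLeft(And)).toSeq

  // Distribute Or over And, grouping each side by qualifier first.
  def toCnf(e: Expr): Seq[Expr] = e match {
    case And(l, r) => toCnf(l) ++ toCnf(r)
    case Or(l, r) =>
      val left = groupByQualifier(toCnf(l))
      val right = groupByQualifier(toCnf(r))
      for (x <- left; y <- right) yield Or(x, y)
    case leaf => Seq(leaf)
  }

  def main(args: Array[String]): Unit = {
    // (t1.a = 1 AND t1.b = 2 AND t2.c = 3) OR t1.a = 4
    val cond = Or(
      And(And(Pred("t1", "a = 1"), Pred("t1", "b = 2")), Pred("t2", "c = 3")),
      Pred("t1", "a = 4"))
    // Two conjuncts instead of three; the one referencing only t1 can be pushed to t1's side.
    toCnf(cond).foreach(println)
  }
}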
+ * @param condition condition need to be convert + * @return expression seq in conjunctive normal form of input expression, if length exceeds + * the threshold [[SQLConf.MAX_CNF_NODE_COUNT]] or length != 1, return empty Seq + */ def conjunctiveNormalFormAndGroupExpsByQualifier(condition: Expression): Seq[Expression] = { conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq) } + /** + * Convert an expression to conjunctive normal form when pushing predicates for partition pruning, + * when expand predicates, we can group by the reference avoiding generate unnecessary expression + * to control the length of final result since here we just have one table. In partition pruning + * strategies, we split filters by [[splitConjunctivePredicates]] and partition filters by judging + * if it's references is subset of partCols, if we combine expressions group by reference when + * expand predicate of [[Or]], it won't impact final predicate pruning result since + * [[splitConjunctivePredicates]] won't split [[Or]] expression. + * @param condition condition need to be convert + * @return expression seq in conjunctive normal form of input expression, if length exceeds + * the threshold [[SQLConf.MAX_CNF_NODE_COUNT]] or length != 1, return empty Seq + */ def conjunctiveNormalFormAndGroupExpsByReference(condition: Expression): Seq[Expression] = { conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index 8f7e1b9719ff5..fd490cf9f34c2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase { - convert = "true" + override def format: String = "parquet" object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala index e560ebc7842b3..eeb5abab74fd1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase { - convert = "false" + override def format(): String = "hive" object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala index 3f8653ec07519..7d0c200f5779d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala @@ -24,17 +24,17 @@ import org.apache.spark.sql.test.SQLTestUtils abstract class 
PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with TestHiveSingleton { - var convert: String = _ + protected def format: String test("SPARK-28169: Convert scan predicate condition to CNF") { - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> convert, - HiveUtils.CONVERT_METASTORE_ORC.key -> convert) { - withTable("t", "temp") { + withTempView("temp") { + withTable("t") { sql( s""" - |CREATE TABLE t(i int) - |PARTITIONED BY (p int) - |STORED AS PARQUET""".stripMargin) + |CREATE TABLE t(i INT, p STRING) + |USING $format + |PARTITIONED BY (p)""".stripMargin) + spark.range(0, 1000, 1).selectExpr("id as col") .createOrReplaceTempView("temp") @@ -42,26 +42,26 @@ abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with sql( s""" |INSERT OVERWRITE TABLE t PARTITION (p='$part') - |select col from temp""".stripMargin) + |SELECT col FROM temp""".stripMargin) } assertPrunedPartitions( "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2) assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (i = 1 or p = '2')", 4) + "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (i = 1 OR p = '2')", 4) assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '3' and i = 3 )", 2) + "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '3' AND i = 3 )", 2) assertPrunedPartitions( - "SELECT * FROM t WHERE (p = '1' and i = 2) or (p = '2' or p = '3')", 3) + "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '2' OR p = '3')", 3) assertPrunedPartitions( "SELECT * FROM t", 4) assertPrunedPartitions( - "SELECT * FROM t where p = '1' and i = 2", 1) + "SELECT * FROM t WHERE p = '1' AND i = 2", 1) assertPrunedPartitions( """ |SELECT i, COUNT(1) FROM ( - |SELECT * FROM t where p = '1' OR (p = '2' AND i = 1) - |) TMP GROUP BY i + |SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1) + |) tmp GROUP BY i """.stripMargin, 2) } } From 35b5813f7015992d2a9ef03a668394e190ba1007 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 30 Jun 2020 19:26:18 +0800 Subject: [PATCH 23/26] follow comment --- .../apache/spark/sql/execution/DataSourceScanExec.scala | 2 +- .../hive/execution/PruneFileSourcePartitionsSuite.scala | 8 ++++---- .../hive/execution/PruneHiveTablePartitionsSuite.scala | 5 +++-- .../sql/hive/execution/PrunePartitionSuiteBase.scala | 7 ++++--- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 0ae39cf8560e6..3430ba7a243b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -205,7 +205,7 @@ case class FileSourceScanExec( private def isDynamicPruningFilter(e: Expression): Boolean = e.find(_.isInstanceOf[PlanExpression[_]]).isDefined - @transient private lazy val selectedPartitions: Array[PartitionDirectory] = { + @transient lazy val selectedPartitions: Array[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) val startTime = System.nanoTime() val ret = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index fd490cf9f34c2..24aecb0274ece 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec @@ -109,9 +109,9 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase { } } - override def getScanExecPartitionSize(query: String): Long = { - sql(query).queryExecution.sparkPlan.collectFirst { + override def getScanExecPartitionSize(plan: SparkPlan): Long = { + plan.collectFirst { case p: FileSourceScanExec => p - }.get.relation.location.inputFiles.length + }.get.selectedPartitions.length } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala index eeb5abab74fd1..ebf02fc75b477 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.execution.SparkPlan class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase { @@ -54,8 +55,8 @@ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase { } } - override def getScanExecPartitionSize(query: String): Long = { - sql(query).queryExecution.sparkPlan.collectFirst { + override def getScanExecPartitionSize(plan: SparkPlan): Long = { + plan.collectFirst { case p: HiveTableScanExec => p }.get.prunedPartitions.size } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala index 7d0c200f5779d..d088061cdc6e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -68,8 +68,9 @@ abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with } protected def assertPrunedPartitions(query: String, expected: Long): Unit = { - assert(getScanExecPartitionSize(query) == expected) + val plan = sql(query).queryExecution.sparkPlan + assert(getScanExecPartitionSize(plan) == expected) } - protected def 
getScanExecPartitionSize(query: String): Long + protected def getScanExecPartitionSize(plan: SparkPlan): Long } From 3df019ae83ad8519329fa4419b7000ebdbe982fa Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 1 Jul 2020 09:37:04 +0800 Subject: [PATCH 24/26] Update predicates.scala --- .../org/apache/spark/sql/catalyst/expressions/predicates.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2efb801bfc53f..7aefe5b48be4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -213,7 +213,7 @@ trait PredicateHelper extends Logging { */ protected def conjunctiveNormalForm( condition: Expression, - groupExpsFunc: Seq[Expression] => Seq[Expression] = _.toSeq): Seq[Expression] = { + groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = { val postOrderNodes = postOrderTraversal(condition) val resultStack = new mutable.Stack[Seq[Expression]] val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount From 1b8466eb5068f97c4f27d25af59dd1a518b4cabe Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 1 Jul 2020 09:49:01 +0800 Subject: [PATCH 25/26] follow comment --- .../sql/catalyst/expressions/predicates.scala | 32 +++++++++---------- .../PruneHiveTablePartitionsSuite.scala | 2 +- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7aefe5b48be4d..7423b6fe484dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -255,31 +255,31 @@ trait PredicateHelper extends Logging { * Convert an expression to conjunctive normal form when pushing predicates through Join, * when expand predicates, we can group by the qualifier avoiding generate unnecessary * expression to control the length of final result since there are multiple tables. - * @param condition condition need to be convert - * @return expression seq in conjunctive normal form of input expression, if length exceeds - * the threshold [[SQLConf.MAX_CNF_NODE_COUNT]] or length != 1, return empty Seq + * + * @param condition condition need to be converted + * @return the CNF result as sequence of disjunctive expressions. If the number of expressions + * exceeds threshold on converting `Or`, `Seq.empty` is returned. */ def conjunctiveNormalFormAndGroupExpsByQualifier(condition: Expression): Seq[Expression] = { - conjunctiveNormalForm(condition, - (expressions: Seq[Expression]) => + conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq) } /** - * Convert an expression to conjunctive normal form when pushing predicates for partition pruning, - * when expand predicates, we can group by the reference avoiding generate unnecessary expression - * to control the length of final result since here we just have one table. 
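As a concrete illustration of the partition-pruning case described here (a sketch of the reasoning only, not Spark code): for the first query in the base suite above, the filter p = '1' OR (p = '2' AND i = 1) converts to the CNF form (p = '1' OR p = '2') AND (p = '1' OR i = 1). Only the first conjunct references the partition column alone, so it is the one used for pruning, which is why 2 of the 4 partitions are expected to survive.

// Minimal standalone check of that expectation (hypothetical values, not Spark internals).
object PartitionPruningSketch {
  val partitionValues = Seq("1", "2", "3", "4")
  // The partition-only conjunct produced by the CNF conversion above.
  val partitionFilter: String => Boolean = p => p == "1" || p == "2"

  def main(args: Array[String]): Unit = {
    val pruned = partitionValues.filter(partitionFilter)
    assert(pruned.size == 2) // matches assertPrunedPartitions(..., 2)
    println(pruned)
  }
}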
In partition pruning - * strategies, we split filters by [[splitConjunctivePredicates]] and partition filters by judging - * if it's references is subset of partCols, if we combine expressions group by reference when - * expand predicate of [[Or]], it won't impact final predicate pruning result since + * Convert an expression to conjunctive normal form for predicate pushdown and partition pruning. + * When expanding predicates, this method groups expressions by their references for reducing + * the size of pushed down predicates and corresponding codegen. In partition pruning strategies, + * we split filters by [[splitConjunctivePredicates]] and partition filters by judging if it's + * references is subset of partCols, if we combine expressions group by reference when expand + * predicate of [[Or]], it won't impact final predicate pruning result since * [[splitConjunctivePredicates]] won't split [[Or]] expression. - * @param condition condition need to be convert - * @return expression seq in conjunctive normal form of input expression, if length exceeds - * the threshold [[SQLConf.MAX_CNF_NODE_COUNT]] or length != 1, return empty Seq + * + * @param condition condition need to be converted + * @return the CNF result as sequence of disjunctive expressions. If the number of expressions + * exceeds threshold on converting `Or`, `Seq.empty` is returned. */ def conjunctiveNormalFormAndGroupExpsByReference(condition: Expression): Seq[Expression] = { - conjunctiveNormalForm(condition, - (expressions: Seq[Expression]) => + conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala index ebf02fc75b477..c29e889c3a941 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala @@ -32,7 +32,7 @@ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase { EliminateSubqueryAliases, new PruneHiveTablePartitions(spark)) :: Nil } - test("SPARK-15616 statistics pruned after going through PruneHiveTablePartitions") { + test("SPARK-15616: statistics pruned after going through PruneHiveTablePartitions") { withTable("test", "temp") { sql( s""" From e2777c99144f91b463924530a0a89e6f1a39ab66 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 1 Jul 2020 16:39:43 +0800 Subject: [PATCH 26/26] follow comment --- .../spark/sql/catalyst/expressions/predicates.scala | 4 ++-- .../optimizer/PushCNFPredicateThroughJoin.scala | 2 +- .../ConjunctiveNormalFormPredicateSuite.scala | 10 +++++----- .../datasources/PruneFileSourcePartitions.scala | 2 +- .../sql/hive/execution/PruneHiveTablePartitions.scala | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7423b6fe484dc..527618b8e2c5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -260,7 +260,7 @@ trait PredicateHelper extends Logging { * @return the CNF result as sequence of 
disjunctive expressions. If the number of expressions * exceeds threshold on converting `Or`, `Seq.empty` is returned. */ - def conjunctiveNormalFormAndGroupExpsByQualifier(condition: Expression): Seq[Expression] = { + def CNFWithGroupExpressionsByQualifier(condition: Expression): Seq[Expression] = { conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq) } @@ -278,7 +278,7 @@ trait PredicateHelper extends Logging { * @return the CNF result as sequence of disjunctive expressions. If the number of expressions * exceeds threshold on converting `Or`, `Seq.empty` is returned. */ - def conjunctiveNormalFormAndGroupExpsByReference(condition: Expression): Seq[Expression] = { + def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = { conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala index cccac032a2b0e..47e9527ead7c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala @@ -38,7 +38,7 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe def apply(plan: LogicalPlan): LogicalPlan = plan transform { case j @ Join(left, right, joinType, Some(joinCondition), hint) if canPushThrough(joinType) => - val predicates = conjunctiveNormalFormAndGroupExpsByQualifier(joinCondition) + val predicates = CNFWithGroupExpressionsByQualifier(joinCondition) if (predicates.isEmpty) { j } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala index fe8eddc19da3e..793abccd79405 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala @@ -43,7 +43,7 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe // Check CNF conversion with expected expression, assuming the input has non-empty result. 
private def checkCondition(input: Expression, expected: Expression): Unit = { - val cnf = conjunctiveNormalFormAndGroupExpsByQualifier(input) + val cnf = CNFWithGroupExpressionsByQualifier(input) assert(cnf.nonEmpty) val result = cnf.reduceLeft(And) assert(result.semanticEquals(expected)) @@ -113,14 +113,14 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe Seq(8, 9, 10, 35, 36, 37).foreach { maxCount => withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) { if (maxCount < 36) { - assert(conjunctiveNormalFormAndGroupExpsByQualifier(input).isEmpty) + assert(CNFWithGroupExpressionsByQualifier(input).isEmpty) } else { - assert(conjunctiveNormalFormAndGroupExpsByQualifier(input).nonEmpty) + assert(CNFWithGroupExpressionsByQualifier(input).nonEmpty) } if (maxCount < 9) { - assert(conjunctiveNormalFormAndGroupExpsByQualifier(input2).isEmpty) + assert(CNFWithGroupExpressionsByQualifier(input2).isEmpty) } else { - assert(conjunctiveNormalFormAndGroupExpsByQualifier(input2).nonEmpty) + assert(CNFWithGroupExpressionsByQualifier(input2).nonEmpty) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 3fa84aedaa78a..576a826faf894 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -88,7 +88,7 @@ private[sql] object PruneFileSourcePartitions _, _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => - val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And)) + val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And)) val finalPredicates = if (predicates.nonEmpty) predicates else filters val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala index 1086133822a56..c4885f2842597 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala @@ -103,7 +103,7 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation) if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty => - val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And)) + val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And)) val finalPredicates = if (predicates.nonEmpty) predicates else filters val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation) if (partitionKeyFilters.nonEmpty) {