diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 05a5ff45b8fb1..527618b8e2c5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -211,7 +211,9 @@ trait PredicateHelper extends Logging {
    * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
    *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
    */
-  protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
+  protected def conjunctiveNormalForm(
+      condition: Expression,
+      groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = {
     val postOrderNodes = postOrderTraversal(condition)
     val resultStack = new mutable.Stack[Seq[Expression]]
     val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
@@ -226,8 +228,8 @@ trait PredicateHelper extends Logging {
         // For each side, there is no need to expand predicates of the same references.
         // So here we can aggregate predicates of the same qualifier as one single predicate,
         // for reducing the size of pushed down predicates and corresponding codegen.
-        val right = groupExpressionsByQualifier(resultStack.pop())
-        val left = groupExpressionsByQualifier(resultStack.pop())
+        val right = groupExpsFunc(resultStack.pop())
+        val left = groupExpsFunc(resultStack.pop())
         // Stop the loop whenever the result exceeds the `maxCnfNodeCount`
         if (left.size * right.size > maxCnfNodeCount) {
           logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
@@ -249,8 +251,36 @@ trait PredicateHelper extends Logging {
     resultStack.top
   }
 
-  private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
-    expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
+  /**
+   * Convert an expression to conjunctive normal form when pushing predicates through Join.
+   * When expanding predicates, we group them by their qualifier to avoid generating unnecessary
+   * expressions and to keep the final result small, since multiple tables are involved.
+   *
+   * @param condition the condition to be converted
+   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
+   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
+   */
+  def CNFWithGroupExpressionsByQualifier(condition: Expression): Seq[Expression] = {
+    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
+      expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
+  }
+
+  /**
+   * Convert an expression to conjunctive normal form for predicate pushdown and partition pruning.
+   * When expanding predicates, this method groups expressions by their references, reducing the
+   * size of the pushed-down predicates and the corresponding codegen. In partition pruning
+   * strategies, we split the filters with [[splitConjunctivePredicates]] and pick out partition
+   * filters by checking whether their references are a subset of the partition columns. Grouping
+   * expressions by reference while expanding an [[Or]] predicate does not change the final
+   * pruning result, because [[splitConjunctivePredicates]] never splits an [[Or]] expression.
+   *
+   * @param condition the condition to be converted
+   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
+   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
+   */
+  def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = {
+    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
+      expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
index 109e5f993c02e..47e9527ead7c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala
@@ -38,7 +38,7 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case j @ Join(left, right, joinType, Some(joinCondition), hint)
         if canPushThrough(joinType) =>
-      val predicates = conjunctiveNormalForm(joinCondition)
+      val predicates = CNFWithGroupExpressionsByQualifier(joinCondition)
       if (predicates.isEmpty) {
         j
       } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
index b449ed5cc0d07..793abccd79405 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala
@@ -43,7 +43,7 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe
 
   // Check CNF conversion with expected expression, assuming the input has non-empty result.
   private def checkCondition(input: Expression, expected: Expression): Unit = {
-    val cnf = conjunctiveNormalForm(input)
+    val cnf = CNFWithGroupExpressionsByQualifier(input)
     assert(cnf.nonEmpty)
     val result = cnf.reduceLeft(And)
     assert(result.semanticEquals(expected))
@@ -113,14 +113,14 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe
     Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
       withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
         if (maxCount < 36) {
-          assert(conjunctiveNormalForm(input).isEmpty)
+          assert(CNFWithGroupExpressionsByQualifier(input).isEmpty)
         } else {
-          assert(conjunctiveNormalForm(input).nonEmpty)
+          assert(CNFWithGroupExpressionsByQualifier(input).nonEmpty)
         }
         if (maxCount < 9) {
-          assert(conjunctiveNormalForm(input2).isEmpty)
+          assert(CNFWithGroupExpressionsByQualifier(input2).isEmpty)
         } else {
-          assert(conjunctiveNormalForm(input2).nonEmpty)
+          assert(CNFWithGroupExpressionsByQualifier(input2).nonEmpty)
         }
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 0ae39cf8560e6..3430ba7a243b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -205,7 +205,7 @@ case class FileSourceScanExec(
   private def isDynamicPruningFilter(e: Expression): Boolean =
     e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
 
-  @transient private lazy val selectedPartitions: Array[PartitionDirectory] = {
+  @transient lazy val selectedPartitions: Array[PartitionDirectory] = {
     val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
     val startTime = System.nanoTime()
     val ret =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index a7129fb14d1a6..576a826faf894 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types.StructType
  * its underlying [[FileScan]]. And the partition filters will be removed in the filters of
  * returned logical plan.
  */
-private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
+private[sql] object PruneFileSourcePartitions
+  extends Rule[LogicalPlan] with PredicateHelper {
 
   private def getPartitionKeyFiltersAndDataFilters(
       sparkSession: SparkSession,
@@ -87,8 +88,12 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
             _,
             _))
         if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
+      val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
+      val finalPredicates = if (predicates.nonEmpty) predicates else filters
       val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
-        fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
+        fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates,
+        logicalRelation.output)
+
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
         val prunedFsRelation =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
index da6e4c52cf3a7..c4885f2842597 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
@@ -21,8 +21,8 @@ import org.apache.hadoop.hive.common.StatsSetupConst
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.CastSupport
-import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, ExternalCatalogUtils, HiveTableRelation}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, SubqueryExpression}
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -41,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf
  * TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source.
  */
 private[sql] class PruneHiveTablePartitions(session: SparkSession)
-  extends Rule[LogicalPlan] with CastSupport {
+  extends Rule[LogicalPlan] with CastSupport with PredicateHelper {
 
   override val conf: SQLConf = session.sessionState.conf
 
@@ -103,7 +103,9 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation)
       if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
-      val partitionKeyFilters = getPartitionKeyFilters(filters, relation)
+      val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
+      val finalPredicates = if (predicates.nonEmpty) predicates else filters
+      val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation)
       if (partitionKeyFilters.nonEmpty) {
         val newPartitions = prunePartitions(relation, partitionKeyFilters)
         val newTableMeta = updateTableMeta(relation.tableMeta, newPartitions)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
index c9c36992906a8..24aecb0274ece 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
@@ -19,22 +19,22 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.Matchers._
 
-import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.functions.broadcast
-import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types.StructType
 
-class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
+
+  override def format: String = "parquet"
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil
@@ -108,4 +108,10 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
       }
     }
   }
+
+  override def getScanExecPartitionSize(plan: SparkPlan): Long = {
+    plan.collectFirst {
+      case p: FileSourceScanExec => p
+    }.get.selectedPartitions.length
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
index e41709841a736..c29e889c3a941 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
@@ -17,14 +17,14 @@
 
 package org.apache.spark.sql.hive.execution
 
-import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.execution.SparkPlan
 
-class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
+
+  override def format(): String = "hive"
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
@@ -32,7 +32,7 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes
       EliminateSubqueryAliases, new PruneHiveTablePartitions(spark)) :: Nil
   }
 
-  test("SPARK-15616 statistics pruned after going throuhg PruneHiveTablePartitions") {
+  test("SPARK-15616: statistics pruned after going through PruneHiveTablePartitions") {
     withTable("test", "temp") {
       sql(
         s"""
@@ -54,4 +54,10 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes
         Optimize.execute(analyzed2).stats.sizeInBytes)
     }
   }
+
+  override def getScanExecPartitionSize(plan: SparkPlan): Long = {
+    plan.collectFirst {
+      case p: HiveTableScanExec => p
+    }.get.prunedPartitions.size
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala
new file mode 100644
index 0000000000000..d088061cdc6e5
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+
+abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with TestHiveSingleton {
+
+  protected def format: String
+
+  test("SPARK-28169: Convert scan predicate condition to CNF") {
+    withTempView("temp") {
+      withTable("t") {
+        sql(
+          s"""
+             |CREATE TABLE t(i INT, p STRING)
+             |USING $format
+             |PARTITIONED BY (p)""".stripMargin)
+
+        spark.range(0, 1000, 1).selectExpr("id as col")
+          .createOrReplaceTempView("temp")
+
+        for (part <- Seq(1, 2, 3, 4)) {
+          sql(
+            s"""
+               |INSERT OVERWRITE TABLE t PARTITION (p='$part')
+               |SELECT col FROM temp""".stripMargin)
+        }
+
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2)
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (i = 1 OR p = '2')", 4)
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '3' AND i = 3 )", 2)
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '2' OR p = '3')", 3)
+        assertPrunedPartitions(
+          "SELECT * FROM t", 4)
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE p = '1' AND i = 2", 1)
+        assertPrunedPartitions(
+          """
+            |SELECT i, COUNT(1) FROM (
+            |SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)
+            |) tmp GROUP BY i
+          """.stripMargin, 2)
+      }
+    }
+  }
+
+  protected def assertPrunedPartitions(query: String, expected: Long): Unit = {
+    val plan = sql(query).queryExecution.sparkPlan
+    assert(getScanExecPartitionSize(plan) == expected)
+  }
+
+  protected def getScanExecPartitionSize(plan: SparkPlan): Long
+}
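
Note (not part of the patch): a minimal sketch of what the new helper does, assuming this change is applied. The attributes `p` and `i` below are hypothetical stand-ins for a partition column and a data column; a filter like p = '1' OR (p = '2' AND i = 1) is rewritten into the conjuncts (p = '1' OR p = '2') and (p = '1' OR i = 1), and only the first conjunct references the partition column, so it can be pushed into partition pruning.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.PredicateHelper

// Illustrative only: any class mixing in PredicateHelper can call the new helper.
object CnfSketch extends PredicateHelper {
  def main(args: Array[String]): Unit = {
    val p = 'p.string  // hypothetical partition column
    val i = 'i.int     // hypothetical data column
    // p = '1' OR (p = '2' AND i = 1)
    val condition = (p === "1") || ((p === "2") && (i === 1))
    // Expected conjuncts: (p = '1' OR p = '2') and (p = '1' OR i = 1).
    CNFWithGroupExpressionsByReference(condition).foreach(c => println(c.sql))
  }
}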