```diff
@@ -267,20 +267,13 @@ trait PredicateHelper extends Logging {

   /**
    * Convert an expression to conjunctive normal form for predicate pushdown and partition pruning.
-   * When expanding predicates, this method groups expressions by their references for reducing
-   * the size of pushed down predicates and corresponding codegen. In partition pruning strategies,
-   * we split filters by [[splitConjunctivePredicates]] and partition filters by judging if it's
-   * references is subset of partCols, if we combine expressions group by reference when expand
-   * predicate of [[Or]], it won't impact final predicate pruning result since
-   * [[splitConjunctivePredicates]] won't split [[Or]] expression.
    *
    * @param condition condition need to be converted
    * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
    *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
    */
-  def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = {
-    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
-      expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
+  def CNFConversion(condition: Expression): Seq[Expression] = {
+    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => expressions)
   }

   /**
```
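For readers of this diff, here is a minimal, self-contained sketch of the CNF expansion that `conjunctiveNormalForm` performs (the toy `Expr` ADT and all names below are illustrative, not Spark's API), showing why expanding an `Or` can surface clauses that reference only partition columns:

```scala
// Toy model of CNF expansion; illustrative only, not Spark code.
sealed trait Expr
case class Pred(name: String) extends Expr
case class And(left: Expr, right: Expr) extends Expr
case class Or(left: Expr, right: Expr) extends Expr

object CnfSketch {
  // Returns the CNF of `e` as a sequence of disjunctive clauses.
  def toCnf(e: Expr): Seq[Expr] = e match {
    case And(l, r) => toCnf(l) ++ toCnf(r)
    // Distribute Or over the conjuncts of both sides:
    // (a AND b) OR c becomes (a OR c) AND (b OR c).
    case Or(l, r)  => for (lc <- toCnf(l); rc <- toCnf(r)) yield Or(lc, rc)
    case p: Pred   => Seq(p)
  }

  def main(args: Array[String]): Unit = {
    // (p0 = 1 AND i = 1) OR (p0 = 2 AND i = 2), where p0 is a partition column
    val cond = Or(And(Pred("p0 = 1"), Pred("i = 1")),
                  And(Pred("p0 = 2"), Pred("i = 2")))
    toCnf(cond).foreach(println)
    // Of the four resulting clauses, Or(p0 = 1, p0 = 2) references only the
    // partition column and can therefore be used for partition pruning.
  }
}
```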
```diff
@@ -53,11 +53,17 @@ private[sql] object PruneFileSourcePartitions
     val partitionColumns =
       relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
     val partitionSet = AttributeSet(partitionColumns)
-    val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
+    val (partitionFilters, remainingFilters) = normalizedFilters.partition(f =>
       f.references.subsetOf(partitionSet)
     )

-    (ExpressionSet(partitionFilters), dataFilters)
+    // Try extracting more convertible partition filters from the remaining filters by converting
+    // them into CNF.
+    val remainingFilterInCnf = remainingFilters.flatMap(CNFConversion)
+    val extraPartitionFilters =
+      remainingFilterInCnf.filter(f => f.references.subsetOf(partitionSet))
+
+    (ExpressionSet(partitionFilters ++ extraPartitionFilters), remainingFilters)
```
Member:

```scala
val (extraPartitionFilters, otherFilters) = remainingFilterInCnf.partition(f =>
  f.references.subsetOf(partitionSet)
)
(ExpressionSet(partitionFilters ++ extraPartitionFilters), otherFilters)
```

?
Member Author:

That way, otherFilters can be very long, which leads to longer codegen... I am avoiding that on purpose. Let me add a comment here.

Member:

okay.
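A back-of-the-envelope sketch of the blow-up being avoided here (illustrative arithmetic, not Spark code): the CNF of an `Or` of n conjunctions with m conjuncts each contains m^n clauses, so returning those clauses as the data filters could make the generated filter code far longer than the original condition.

```scala
// Illustrative only: CNF clause count for an OR of n branches,
// each branch an AND of m predicates, is m^n.
object CnfBlowup {
  def clauseCount(m: Int, n: Int): BigInt = BigInt(m).pow(n)

  def main(args: Array[String]): Unit = {
    // Hypothetical shape: 10 OR branches of 3 predicates each (30 predicates total)
    println(clauseCount(3, 10)) // 59049 CNF clauses
  }
}
```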

```diff
   }

   private def rebuildPhysicalOperation(
@@ -88,12 +94,9 @@ private[sql] object PruneFileSourcePartitions
         _,
         _))
       if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
-      val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
-      val finalPredicates = if (predicates.nonEmpty) predicates else filters
       val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
-        fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates,
+        fsRelation.sparkSession, logicalRelation, partitionSchema, filters,
         logicalRelation.output)

       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
         val prunedFsRelation =
```
```diff
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions.CNFConversion
 import org.apache.spark.sql.internal.SQLConf
```
Contributor:

This import is not necessary.


```diff
@@ -54,9 +55,15 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
     val normalizedFilters = DataSourceStrategy.normalizeExprs(
       filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output)
     val partitionColumnSet = AttributeSet(relation.partitionCols)
-    ExpressionSet(normalizedFilters.filter { f =>
+    val (partitionFilters, remainingFilters) = normalizedFilters.partition { f =>
       !f.references.isEmpty && f.references.subsetOf(partitionColumnSet)
-    })
+    }
+    // Try extracting more convertible partition filters from the remaining filters by converting
+    // them into CNF.
+    val remainingFilterInCnf = remainingFilters.flatMap(CNFConversion)
+    val extraPartitionFilters = remainingFilterInCnf.filter(f =>
+      !f.references.isEmpty && f.references.subsetOf(partitionColumnSet))
+    ExpressionSet(partitionFilters ++ extraPartitionFilters)
```
Contributor:

I am confused: it seems CNFConversion won't change references, so don't you need to call splitConjunctivePredicates on each expression in remainingFilterInCnf to extract more predicates?

Member Author:

The filters here are already processed with splitConjunctivePredicates in PhysicalOperation.unapply. That's why the original code before #28805 didn't call splitConjunctivePredicates either.
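For context, a sketch of what `splitConjunctivePredicates` does, reusing the toy `Expr` ADT from the earlier CNF note (Spark's real helper lives in `PredicateHelper`; this is an assumption-level illustration): it only flattens nested `And`s and never splits an `Or`, which is why the conjuncts produced by `PhysicalOperation.unapply` need no further splitting here.

```scala
// Sketch: flatten a condition into its top-level conjuncts.
// An Or node is returned whole, mirroring Spark's behavior.
def splitConjunctive(e: Expr): Seq[Expr] = e match {
  case And(l, r) => splitConjunctive(l) ++ splitConjunctive(r)
  case other     => Seq(other)
}
```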

```diff
   }

   /**
@@ -103,7 +110,7 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation)
       if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
-      val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
+      val predicates = CNFConversion(filters.reduceLeft(And))
```
Member:

nit: `conjunctiveNormalForm(filters.reduceLeft(And), identity)`?
```diff
       val finalPredicates = if (predicates.nonEmpty) predicates else filters
       val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation)
       if (partitionKeyFilters.nonEmpty) {
```
```diff
@@ -67,6 +67,31 @@ abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with
     }
   }

+  test("SPARK-32284: Avoid pushing down too many predicates in partition pruning") {
+    withTempView("temp") {
+      withTable("t") {
+        sql(
+          s"""
+             |CREATE TABLE t(i INT, p0 INT, p1 INT)
+             |USING $format
+             |PARTITIONED BY (p0, p1)""".stripMargin)
+
+        spark.range(0, 10, 1).selectExpr("id as col")
+          .createOrReplaceTempView("temp")
+
+        for (part <- (0 to 25)) {
+          sql(
+            s"""
+               |INSERT OVERWRITE TABLE t PARTITION (p0='$part', p1='$part')
+               |SELECT col FROM temp""".stripMargin)
+        }
+        val scale = 20
+        val predicate = (1 to scale).map(i => s"(p0 = '$i' AND p1 = '$i')").mkString(" OR ")
+        assertPrunedPartitions(s"SELECT * FROM t WHERE $predicate", scale)
+      }
+    }
+  }
+
   protected def assertPrunedPartitions(query: String, expected: Long): Unit = {
     val plan = sql(query).queryExecution.sparkPlan
     assert(getScanExecPartitionSize(plan) == expected)
```
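One reading of why the test expects exactly `scale` partitions (a sketch of the reasoning, not an authoritative claim): 26 partitions are inserted, but the filter is an `Or` of 20 branches that each touch only partition columns, so after CNF conversion the extracted clauses reference only `p0`/`p1` and the scan should be pruned to the 20 matching partitions. A hypothetical standalone rendering of the predicate string at a small scale:

```scala
// Hypothetical, scale = 3:
val predicate = (1 to 3).map(i => s"(p0 = '$i' AND p1 = '$i')").mkString(" OR ")
// (p0 = '1' AND p1 = '1') OR (p0 = '2' AND p1 = '2') OR (p0 = '3' AND p1 = '3')
```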