[SPARK-28169][SQL] Convert scan predicate condition to CNF #28805
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -211,7 +211,9 @@ trait PredicateHelper extends Logging {
    * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
    *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
    */
-  protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
+  protected def conjunctiveNormalForm(
+      condition: Expression,
+      groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = {
     val postOrderNodes = postOrderTraversal(condition)
     val resultStack = new mutable.Stack[Seq[Expression]]
     val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
@@ -226,8 +228,8 @@ trait PredicateHelper extends Logging {
         // For each side, there is no need to expand predicates of the same references.
         // So here we can aggregate predicates of the same qualifier as one single predicate,
         // for reducing the size of pushed down predicates and corresponding codegen.
-        val right = groupExpressionsByQualifier(resultStack.pop())
-        val left = groupExpressionsByQualifier(resultStack.pop())
+        val right = groupExpsFunc(resultStack.pop())
+        val left = groupExpsFunc(resultStack.pop())
         // Stop the loop whenever the result exceeds the `maxCnfNodeCount`
         if (left.size * right.size > maxCnfNodeCount) {
           logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
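The `left.size * right.size > maxCnfNodeCount` guard above is what keeps the conversion from blowing up: distributing `Or` over two conjunct lists produces their cross product. A throwaway Scala sketch of that arithmetic (illustrative names, not code from the patch):

// Distributing OR over two conjunct lists yields one clause per (left, right) pair.
val left = Seq("a1", "a2")           // conjuncts of the Or's left child
val right = Seq("b1", "b2", "b3")    // conjuncts of the Or's right child

// (a1 AND a2) OR (b1 AND b2 AND b3)
//   ==> (a1 OR b1) AND (a1 OR b2) AND ... : 2 * 3 = 6 clauses
val expanded = for (l <- left; r <- right) yield s"($l OR $r)"
assert(expanded.size == left.size * right.size)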
@@ -249,8 +251,36 @@
     resultStack.top
   }
 
-  private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
-    expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
+  /**
+   * Convert an expression to conjunctive normal form when pushing predicates through Join.
+   * When expanding predicates, we group them by qualifier: since multiple tables are
+   * involved, this avoids generating unnecessary expressions and keeps the final result
+   * short.
+   *
+   * @param condition the condition to be converted
+   * @return the CNF result as a sequence of disjunctive expressions. If the number of
+   *         expressions exceeds the threshold on converting `Or`, `Seq.empty` is returned.
+   */
+  def conjunctiveNormalFormAndGroupExpsByQualifier(condition: Expression): Seq[Expression] = {
+    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
+      expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
   }
+
+  /**
+   * Convert an expression to conjunctive normal form for predicate pushdown and partition
+   * pruning. When expanding predicates, this method groups expressions by their references,
+   * reducing the size of the pushed-down predicates and the corresponding codegen. In
+   * partition pruning strategies, we split filters with [[splitConjunctivePredicates]] and
+   * treat a filter as a partition filter when its references are a subset of partCols;
+   * combining expressions grouped by reference while expanding an [[Or]] predicate does not
+   * affect the final pruning result, since [[splitConjunctivePredicates]] never splits an
+   * [[Or]] expression.
+   *
+   * @param condition the condition to be converted
+   * @return the CNF result as a sequence of disjunctive expressions. If the number of
+   *         expressions exceeds the threshold on converting `Or`, `Seq.empty` is returned.
+   */
+  def conjunctiveNormalFormAndGroupExpsByReference(condition: Expression): Seq[Expression] = {
+    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
+      expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
+  }
 
   /**
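To make the transformation concrete, here is a minimal, self-contained sketch of CNF conversion over a toy expression tree (all names hypothetical; it uses plain recursion rather than the patch's stack-based post-order traversal, and omits the `maxCnfNodeCount` cap and the grouping step):

sealed trait Expr
case class Leaf(sql: String) extends Expr
case class And(l: Expr, r: Expr) extends Expr
case class Or(l: Expr, r: Expr) extends Expr

// Returns the conjuncts of the CNF form: And concatenates the children's
// conjuncts, Or distributes over them (the cross product).
def toCnf(e: Expr): Seq[Expr] = e match {
  case And(l, r) => toCnf(l) ++ toCnf(r)
  case Or(l, r)  => for (lc <- toCnf(l); rc <- toCnf(r)) yield Or(lc, rc)
  case leaf      => Seq(leaf)
}

// p = '1' OR (p = '2' AND i = 1)
val cond = Or(Leaf("p = '1'"), And(Leaf("p = '2'"), Leaf("i = 1")))
// ==> Seq((p = '1' OR p = '2'), (p = '1' OR i = 1))
println(toCnf(cond))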
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types.StructType
  * its underlying [[FileScan]]. And the partition filters will be removed in the filters of
  * returned logical plan.
  */
-private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
+private[sql] object PruneFileSourcePartitions
+  extends Rule[LogicalPlan] with PredicateHelper {
 
   private def getPartitionKeyFiltersAndDataFilters(
       sparkSession: SparkSession,

@@ -87,8 +88,12 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
         _,
         _))
       if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
+      val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And))
+      val finalPredicates = if (predicates.nonEmpty) predicates else filters
       val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
-        fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
+        fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates,
+        logicalRelation.output)
 
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
         val prunedFsRelation =
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala (new file)
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+
+abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with TestHiveSingleton {
+
+  protected def format: String
+
+  test("SPARK-28169: Convert scan predicate condition to CNF") {
+    withTempView("temp") {
+      withTable("t") {
+        sql(
+          s"""
+             |CREATE TABLE t(i INT, p STRING)
+             |USING $format
+             |PARTITIONED BY (p)""".stripMargin)
+
+        spark.range(0, 1000, 1).selectExpr("id as col")
+          .createOrReplaceTempView("temp")
+
+        for (part <- Seq(1, 2, 3, 4)) {
+          sql(
+            s"""
+               |INSERT OVERWRITE TABLE t PARTITION (p='$part')
+               |SELECT col FROM temp""".stripMargin)
+        }
+
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2)
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (i = 1 OR p = '2')", 4)
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '3' AND i = 3)", 2)
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '2' OR p = '3')", 3)
+        assertPrunedPartitions(
+          "SELECT * FROM t", 4)
+        assertPrunedPartitions(
+          "SELECT * FROM t WHERE p = '1' AND i = 2", 1)
+        assertPrunedPartitions(
+          """
+            |SELECT i, COUNT(1) FROM (
+            |SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)
+            |) tmp GROUP BY i
+          """.stripMargin, 2)
+      }
+    }
+  }
+
+  protected def assertPrunedPartitions(query: String, expected: Long): Unit = {
+    val plan = sql(query).queryExecution.sparkPlan
+    assert(getScanExecPartitionSize(plan) == expected)
+  }
+
+  protected def getScanExecPartitionSize(plan: SparkPlan): Long
+}