
Commit d0c83f3

gengliangwang authored and cloud-fan committed
[SPARK-32302][SQL] Partially push down disjunctive predicates through Join/Partitions
### What changes were proposed in this pull request?

In #28733 and #28805, CNF conversion is used to push down disjunctive predicates through join and partition pruning. It is a good improvement; however, converting all the predicates into CNF can lead to a very long result, even with grouping functions over expressions. For example, the following predicate

```
(p0 = '1' AND p1 = '1') OR (p0 = '2' AND p1 = '2') OR (p0 = '3' AND p1 = '3') OR
(p0 = '4' AND p1 = '4') OR (p0 = '5' AND p1 = '5') OR (p0 = '6' AND p1 = '6') OR
(p0 = '7' AND p1 = '7') OR (p0 = '8' AND p1 = '8') OR (p0 = '9' AND p1 = '9') OR
(p0 = '10' AND p1 = '10') OR (p0 = '11' AND p1 = '11') OR (p0 = '12' AND p1 = '12') OR
(p0 = '13' AND p1 = '13') OR (p0 = '14' AND p1 = '14') OR (p0 = '15' AND p1 = '15') OR
(p0 = '16' AND p1 = '16') OR (p0 = '17' AND p1 = '17') OR (p0 = '18' AND p1 = '18') OR
(p0 = '19' AND p1 = '19') OR (p0 = '20' AND p1 = '20')
```

is converted into a very long query (130K characters) against the Hive metastore, which fails with:

```
javax.jdo.JDOException: Exception thrown when executing query : SELECT DISTINCT 'org.apache.hadoop.hive.metastore.model.MPartition' AS NUCLEUS_TYPE,A0.CREATE_TIME,A0.LAST_ACCESS_TIME,A0.PART_NAME,A0.PART_ID,A0.PART_NAME AS NUCORDER0 FROM PARTITIONS A0 LEFT OUTER JOIN TBLS B0 ON A0.TBL_ID = B0.TBL_ID LEFT OUTER JOIN DBS C0 ON B0.DB_ID = C0.DB_ID WHERE B0.TBL_NAME = ? AND C0."NAME" = ? AND ((((((A0.PART_NAME LIKE '%/p1=1' ESCAPE '\' ) OR (A0.PART_NAME LIKE '%/p1=2' ESCAPE '\' )) OR (A0.PART_NAME LIKE '%/p1=3' ESCAPE '\' )) OR ((A0.PART_NAME LIKE '%/p1=4' ESCAPE '\' ) O ...
```

Essentially, we just need to traverse the predicate and extract the convertible sub-predicates, as we did in #24598. There is no need to maintain the CNF result set.

### Why are the changes needed?

A better implementation for pushing down disjunctive and complex predicates. The pushed-down predicates are always equal to or shorter than the CNF result.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests

Closes #29101 from gengliangwang/pushJoin.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
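The core idea of the patch can be sketched with a self-contained toy model (the names `Pred`, `Leaf`, `AndP`, `OrP`, and `extractWithin` are illustrative, not from the patch): recurse over the `And`/`Or` tree and keep the maximal sub-predicate whose references fall inside the target attribute set.

```
// Toy sketch of the extraction idea; all names here are hypothetical.
sealed trait Pred
final case class Leaf(refs: Set[String]) extends Pred // an atomic predicate and its column references
final case class AndP(l: Pred, r: Pred) extends Pred
final case class OrP(l: Pred, r: Pred) extends Pred

def extractWithin(p: Pred, outputSet: Set[String]): Option[Pred] = p match {
  // And: each side alone is an implied constraint, so keep whatever converts.
  case AndP(l, r) =>
    (extractWithin(l, outputSet), extractWithin(r, outputSet)) match {
      case (Some(a), Some(b)) => Some(AndP(a, b))
      case (Some(a), None) => Some(a)
      case (None, Some(b)) => Some(b)
      case _ => None
    }
  // Or: both branches must convert, otherwise the disjunction is not implied.
  case OrP(l, r) =>
    for {
      a <- extractWithin(l, outputSet)
      b <- extractWithin(r, outputSet)
    } yield OrP(a, b)
  // Leaf: keep it only if all its references are available.
  case leaf @ Leaf(refs) =>
    if (refs.subsetOf(outputSet)) Some(leaf) else None
}
```

Unlike CNF conversion, this recursion never multiplies clauses, so the extracted predicate is never longer than the input.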
1 parent c2afe1c · commit d0c83f3

10 files changed: +208 −324 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 39 additions & 114 deletions

```
@@ -202,125 +202,50 @@ trait PredicateHelper extends Logging {
   }
 
   /**
-   * Convert an expression into conjunctive normal form.
-   * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
-   * CNF can explode exponentially in the size of the input expression when converting [[Or]]
-   * clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases.
-   *
-   * @param condition to be converted into CNF.
-   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
-   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
+   * Returns a filter that its reference is a subset of `outputSet` and it contains the maximum
+   * constraints from `condition`. This is used for predicate pushdown.
+   * When there is no such filter, `None` is returned.
    */
-  protected def conjunctiveNormalForm(
+  protected def extractPredicatesWithinOutputSet(
       condition: Expression,
-      groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = {
-    val postOrderNodes = postOrderTraversal(condition)
-    val resultStack = new mutable.Stack[Seq[Expression]]
-    val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
-    // Bottom up approach to get CNF of sub-expressions
-    while (postOrderNodes.nonEmpty) {
-      val cnf = postOrderNodes.pop() match {
-        case _: And =>
-          val right = resultStack.pop()
-          val left = resultStack.pop()
-          left ++ right
-        case _: Or =>
-          // For each side, there is no need to expand predicates of the same references.
-          // So here we can aggregate predicates of the same qualifier as one single predicate,
-          // for reducing the size of pushed down predicates and corresponding codegen.
-          val right = groupExpsFunc(resultStack.pop())
-          val left = groupExpsFunc(resultStack.pop())
-          // Stop the loop whenever the result exceeds the `maxCnfNodeCount`
-          if (left.size * right.size > maxCnfNodeCount) {
-            logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
-              "The CNF conversion is skipped and returning Seq.empty now. To avoid this, you can " +
-              s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.")
-            return Seq.empty
-          } else {
-            for { x <- left; y <- right } yield Or(x, y)
-          }
-        case other => other :: Nil
+      outputSet: AttributeSet): Option[Expression] = condition match {
+    case And(left, right) =>
+      val leftResultOptional = extractPredicatesWithinOutputSet(left, outputSet)
+      val rightResultOptional = extractPredicatesWithinOutputSet(right, outputSet)
+      (leftResultOptional, rightResultOptional) match {
+        case (Some(leftResult), Some(rightResult)) => Some(And(leftResult, rightResult))
+        case (Some(leftResult), None) => Some(leftResult)
+        case (None, Some(rightResult)) => Some(rightResult)
+        case _ => None
       }
-      resultStack.push(cnf)
-    }
-    if (resultStack.length != 1) {
-      logWarning("The length of CNF conversion result stack is supposed to be 1. There might " +
-        "be something wrong with CNF conversion.")
-      return Seq.empty
-    }
-    resultStack.top
-  }
-
-  /**
-   * Convert an expression to conjunctive normal form when pushing predicates through Join,
-   * when expand predicates, we can group by the qualifier avoiding generate unnecessary
-   * expression to control the length of final result since there are multiple tables.
-   *
-   * @param condition condition need to be converted
-   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
-   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
-   */
-  def CNFWithGroupExpressionsByQualifier(condition: Expression): Seq[Expression] = {
-    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
-      expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
-  }
-
-  /**
-   * Convert an expression to conjunctive normal form for predicate pushdown and partition pruning.
-   * When expanding predicates, this method groups expressions by their references for reducing
-   * the size of pushed down predicates and corresponding codegen. In partition pruning strategies,
-   * we split filters by [[splitConjunctivePredicates]] and partition filters by judging if it's
-   * references is subset of partCols, if we combine expressions group by reference when expand
-   * predicate of [[Or]], it won't impact final predicate pruning result since
-   * [[splitConjunctivePredicates]] won't split [[Or]] expression.
-   *
-   * @param condition condition need to be converted
-   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
-   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
-   */
-  def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = {
-    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
-      expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
-  }
 
-  /**
-   * Iterative post order traversal over a binary tree built by And/Or clauses with two stacks.
-   * For example, a condition `(a And b) Or c`, the postorder traversal is
-   * (`a`, `b`, `And`, `c`, `Or`).
-   * Following is the complete algorithm. After step 2, we get the postorder traversal in
-   * the second stack.
-   * 1. Push root to first stack.
-   * 2. Loop while first stack is not empty
-   * 2.1 Pop a node from first stack and push it to second stack
-   * 2.2 Push the children of the popped node to first stack
-   *
-   * @param condition to be traversed as binary tree
-   * @return sub-expressions in post order traversal as a stack.
-   *         The first element of result stack is the leftmost node.
-   */
-  private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
-    val stack = new mutable.Stack[Expression]
-    val result = new mutable.Stack[Expression]
-    stack.push(condition)
-    while (stack.nonEmpty) {
-      val node = stack.pop()
-      node match {
-        case Not(a And b) => stack.push(Or(Not(a), Not(b)))
-        case Not(a Or b) => stack.push(And(Not(a), Not(b)))
-        case Not(Not(a)) => stack.push(a)
-        case a And b =>
-          result.push(node)
-          stack.push(a)
-          stack.push(b)
-        case a Or b =>
-          result.push(node)
-          stack.push(a)
-          stack.push(b)
-        case _ =>
-          result.push(node)
+    // The Or predicate is convertible when both of its children can be pushed down.
+    // That is to say, if one/both of the children can be partially pushed down, the Or
+    // predicate can be partially pushed down as well.
+    //
+    // Here is an example used to explain the reason.
+    // Let's say we have
+    // condition: (a1 AND a2) OR (b1 AND b2),
+    // outputSet: AttributeSet(a1, b1)
+    // a1 and b1 is convertible, while a2 and b2 is not.
+    // The predicate can be converted as
+    // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2)
+    // As per the logical in And predicate, we can push down (a1 OR b1).
+    case Or(left, right) =>
+      for {
+        lhs <- extractPredicatesWithinOutputSet(left, outputSet)
+        rhs <- extractPredicatesWithinOutputSet(right, outputSet)
+      } yield Or(lhs, rhs)
+
+    // Here we assume all the `Not` operators is already below all the `And` and `Or` operators
+    // after the optimization rule `BooleanSimplification`, so that we don't need to handle the
+    // `Not` operators here.
+    case other =>
+      if (other.references.subsetOf(outputSet)) {
+        Some(other)
+      } else {
+        None
       }
-    }
-    result
   }
 }
```
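For intuition, the example from the code comment above can be exercised with Catalyst's test DSL. This is a hedged sketch, not a test from the patch; since `extractPredicatesWithinOutputSet` is `protected`, the sketch mixes in `PredicateHelper`, and the attribute names are local assumptions:

```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, PredicateHelper}

object ExtractDemo extends PredicateHelper {
  val a1 = 'a1.int; val a2 = 'a2.int
  val b1 = 'b1.int; val b2 = 'b2.int

  // condition: (a1 = 1 AND a2 = 1) OR (b1 = 1 AND b2 = 1)
  val condition = (a1 === 1 && a2 === 1) || (b1 === 1 && b2 === 1)

  // Only a1 and b1 are in the target output set, so the extraction keeps
  // one convertible conjunct per Or branch: Some((a1 = 1) OR (b1 = 1)).
  val extracted = extractPredicatesWithinOutputSet(condition, AttributeSet(a1 :: b1 :: Nil))
}
```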
326251

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 4 additions & 4 deletions

```
@@ -51,8 +51,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val excludedOnceBatches: Set[String] =
     Set(
       "PartitionPruning",
-      "Extract Python UDFs",
-      "Push CNF predicate through join")
+      "Extract Python UDFs")
 
   protected def fixedPoint =
     FixedPoint(
@@ -123,8 +122,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
       rulesWithoutInferFiltersFromConstraints: _*) ::
     // Set strategy to Once to avoid pushing filter every time because we do not change the
     // join condition.
-    Batch("Push CNF predicate through join", Once,
-      PushCNFPredicateThroughJoin) :: Nil
+    Batch("Push extra predicate through join", fixedPoint,
+      PushExtraPredicateThroughJoin,
+      PushDownPredicates) :: Nil
   }
 
   val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala renamed to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushExtraPredicateThroughJoin.scala

Lines changed: 26 additions & 16 deletions

```
@@ -17,18 +17,20 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, PredicateHelper}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 
 /**
- * Try converting join condition to conjunctive normal form expression so that more predicates may
- * be able to be pushed down.
+ * Try pushing down disjunctive join condition into left and right child.
  * To avoid expanding the join condition, the join condition will be kept in the original form even
  * when predicate pushdown happens.
  */
-object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
+object PushExtraPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
+
+  private val processedJoinConditionTag = TreeNodeTag[Expression]("processedJoinCondition")
 
   private def canPushThrough(joinType: JoinType): Boolean = joinType match {
     case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftAnti | ExistenceJoin(_) => true
@@ -38,22 +40,28 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case j @ Join(left, right, joinType, Some(joinCondition), hint)
         if canPushThrough(joinType) =>
-      val predicates = CNFWithGroupExpressionsByQualifier(joinCondition)
-      if (predicates.isEmpty) {
+      val alreadyProcessed = j.getTagValue(processedJoinConditionTag).exists { condition =>
+        condition.semanticEquals(joinCondition)
+      }
+
+      lazy val filtersOfBothSide = splitConjunctivePredicates(joinCondition).filter { f =>
+        f.deterministic && f.references.nonEmpty &&
+          !f.references.subsetOf(left.outputSet) && !f.references.subsetOf(right.outputSet)
+      }
+      lazy val leftExtraCondition =
+        filtersOfBothSide.flatMap(extractPredicatesWithinOutputSet(_, left.outputSet))
+      lazy val rightExtraCondition =
+        filtersOfBothSide.flatMap(extractPredicatesWithinOutputSet(_, right.outputSet))
+
+      if (alreadyProcessed || (leftExtraCondition.isEmpty && rightExtraCondition.isEmpty)) {
         j
       } else {
-        val pushDownCandidates = predicates.filter(_.deterministic)
-        lazy val leftFilterConditions =
-          pushDownCandidates.filter(_.references.subsetOf(left.outputSet))
-        lazy val rightFilterConditions =
-          pushDownCandidates.filter(_.references.subsetOf(right.outputSet))
-
         lazy val newLeft =
-          leftFilterConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+          leftExtraCondition.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
         lazy val newRight =
-          rightFilterConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+          rightExtraCondition.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
 
-        joinType match {
+        val newJoin = joinType match {
           case _: InnerLike | LeftSemi =>
             Join(newLeft, newRight, joinType, Some(joinCondition), hint)
           case RightOuter =>
@@ -63,6 +71,8 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           case other =>
             throw new IllegalStateException(s"Unexpected join type: $other")
         }
-      }
+        newJoin.setTagValue(processedJoinConditionTag, joinCondition)
+        newJoin
+      }
   }
 }
```
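To see the renamed rule end to end, here is a sketch in the style of Spark's optimizer test suites (`t1`, `t2`, and the literals are local assumptions, not code from the patch):

```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val t1 = LocalRelation('a.int)
val t2 = LocalRelation('b.int)

// Every conjunct of this disjunction references both sides of the join.
val cond = ('a === 1 && 'b === 1) || ('a === 2 && 'b === 2)
val query = t1.join(t2, Inner, Some(cond)).analyze

// The rule keeps the join condition in its original form but inserts
// Filter(a = 1 OR a = 2) below the left child and
// Filter(b = 1 OR b = 2) below the right child.
val optimized = PushExtraPredicateThroughJoin(query)
```

Running the rule again on `optimized` returns the plan unchanged thanks to `processedJoinConditionTag`, which is why the batch can safely run at `fixedPoint` together with `PushDownPredicates`.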

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 0 additions & 15 deletions

```
@@ -545,19 +545,6 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
-  val MAX_CNF_NODE_COUNT =
-    buildConf("spark.sql.optimizer.maxCNFNodeCount")
-      .internal()
-      .doc("Specifies the maximum allowable number of conjuncts in the result of CNF " +
-        "conversion. If the conversion exceeds the threshold, an empty sequence is returned. " +
-        "For example, CNF conversion of (a && b) || (c && d) generates " +
-        "four conjuncts (a || c) && (a || d) && (b || c) && (b || d).")
-      .version("3.1.0")
-      .intConf
-      .checkValue(_ >= 0,
-        "The depth of the maximum rewriting conjunction normal form must be positive.")
-      .createWithDefault(128)
-
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -2954,8 +2941,6 @@ class SQLConf extends Serializable with Logging {
 
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
-  def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT)
-
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
   def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
```
