
Commit 16c78d6

better solution for pushing extra predicates through join
1 parent b05f309 commit 16c78d6

9 files changed: +99 -378 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 40 additions & 116 deletions
@@ -201,126 +201,50 @@ trait PredicateHelper extends Logging {
     case e => e.children.forall(canEvaluateWithinJoin)
   }
 
-  /**
-   * Convert an expression into conjunctive normal form.
-   * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
-   * CNF can explode exponentially in the size of the input expression when converting [[Or]]
-   * clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases.
-   *
-   * @param condition to be converted into CNF.
-   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
-   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
+  /*
+   * Returns a filter whose output is a subset of `outputSet` and which contains all possible
+   * constraints from `condition`. This is used for predicate pushdown.
+   * When there is no such convertible filter, `None` is returned.
    */
-  protected def conjunctiveNormalForm(
-      condition: Expression,
-      groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = {
-    val postOrderNodes = postOrderTraversal(condition)
-    val resultStack = new mutable.Stack[Seq[Expression]]
-    val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
-    // Bottom up approach to get CNF of sub-expressions
-    while (postOrderNodes.nonEmpty) {
-      val cnf = postOrderNodes.pop() match {
-        case _: And =>
-          val right = resultStack.pop()
-          val left = resultStack.pop()
-          left ++ right
-        case _: Or =>
-          // For each side, there is no need to expand predicates of the same references.
-          // So here we can aggregate predicates of the same qualifier as one single predicate,
-          // for reducing the size of pushed down predicates and corresponding codegen.
-          val right = groupExpsFunc(resultStack.pop())
-          val left = groupExpsFunc(resultStack.pop())
-          // Stop the loop whenever the result exceeds the `maxCnfNodeCount`
-          if (left.size * right.size > maxCnfNodeCount) {
-            logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
-              "The CNF conversion is skipped and returning Seq.empty now. To avoid this, you can " +
-              s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.")
-            return Seq.empty
-          } else {
-            for { x <- left; y <- right } yield Or(x, y)
-          }
-        case other => other :: Nil
+  protected def convertibleFilter(
+      condition: Expression,
+      outputSet: AttributeSet): Option[Expression] = condition match {
+    case And(left, right) =>
+      val leftResultOptional = convertibleFilter(left, outputSet)
+      val rightResultOptional = convertibleFilter(right, outputSet)
+      (leftResultOptional, rightResultOptional) match {
+        case (Some(leftResult), Some(rightResult)) => Some(And(leftResult, rightResult))
+        case (Some(leftResult), None) => Some(leftResult)
+        case (None, Some(rightResult)) => Some(rightResult)
+        case _ => None
       }
-      resultStack.push(cnf)
-    }
-    if (resultStack.length != 1) {
-      logWarning("The length of CNF conversion result stack is supposed to be 1. There might " +
-        "be something wrong with CNF conversion.")
-      return Seq.empty
-    }
-    resultStack.top
-  }
-
-  /**
-   * Convert an expression to conjunctive normal form when pushing predicates through Join,
-   * when expand predicates, we can group by the qualifier avoiding generate unnecessary
-   * expression to control the length of final result since there are multiple tables.
-   *
-   * @param condition condition need to be converted
-   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
-   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
-   */
-  def CNFWithGroupExpressionsByQualifier(condition: Expression): Seq[Expression] = {
-    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
-      expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
-  }
-
-  /**
-   * Convert an expression to conjunctive normal form for predicate pushdown and partition pruning.
-   * When expanding predicates, this method groups expressions by their references for reducing
-   * the size of pushed down predicates and corresponding codegen. In partition pruning strategies,
-   * we split filters by [[splitConjunctivePredicates]] and partition filters by judging if it's
-   * references is subset of partCols, if we combine expressions group by reference when expand
-   * predicate of [[Or]], it won't impact final predicate pruning result since
-   * [[splitConjunctivePredicates]] won't split [[Or]] expression.
-   *
-   * @param condition condition need to be converted
-   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
-   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
-   */
-  def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = {
-    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
-      expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
-  }
 
-  /**
-   * Iterative post order traversal over a binary tree built by And/Or clauses with two stacks.
-   * For example, a condition `(a And b) Or c`, the postorder traversal is
-   * (`a`, `b`, `And`, `c`, `Or`).
-   * Following is the complete algorithm. After step 2, we get the postorder traversal in
-   * the second stack.
-   * 1. Push root to first stack.
-   * 2. Loop while first stack is not empty
-   *   2.1 Pop a node from first stack and push it to second stack
-   *   2.2 Push the children of the popped node to first stack
-   *
-   * @param condition to be traversed as binary tree
-   * @return sub-expressions in post order traversal as a stack.
-   *         The first element of result stack is the leftmost node.
-   */
-  private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
-    val stack = new mutable.Stack[Expression]
-    val result = new mutable.Stack[Expression]
-    stack.push(condition)
-    while (stack.nonEmpty) {
-      val node = stack.pop()
-      node match {
-        case Not(a And b) => stack.push(Or(Not(a), Not(b)))
-        case Not(a Or b) => stack.push(And(Not(a), Not(b)))
-        case Not(Not(a)) => stack.push(a)
-        case a And b =>
-          result.push(node)
-          stack.push(a)
-          stack.push(b)
-        case a Or b =>
-          result.push(node)
-          stack.push(a)
-          stack.push(b)
-        case _ =>
-          result.push(node)
+    // The Or predicate is convertible when both of its children can be pushed down.
+    // That is to say, if one or both of the children can be partially pushed down, the Or
+    // predicate can be partially pushed down as well.
+    //
+    // Here is an example to explain the reason.
+    // Let's say we have
+    //   (a1 AND a2) OR (b1 AND b2),
+    // where a1 and b1 are convertible, while a2 and b2 are not.
+    // The predicate can be converted as
+    //   (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2)
+    // As per the logic of the And case above, we can push down (a1 OR b1).
+    case Or(left, right) =>
+      for {
+        lhs <- convertibleFilter(left, outputSet)
+        rhs <- convertibleFilter(right, outputSet)
+      } yield Or(lhs, rhs)
+
+    // Here we assume all the `Not` operators are already below all the `And` and `Or` operators
+    // after the optimization rule `BooleanSimplification`, so we don't need to handle the
+    // `Not` operators here.
+    case other =>
+      if (other.references.subsetOf(outputSet)) {
+        Some(other)
+      } else {
+        None
       }
-    }
-    result
   }
 }
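For intuition, here is a minimal, runnable sketch of the same convertible-filter extraction over a toy AST. The `Pred` type, its `refs` field, and the attribute names `"l"`/`"r"` are hypothetical stand-ins for Catalyst's `Expression` and `AttributeSet`, not the real API:

```scala
// A sketch of the convertible-filter idea under assumed toy types.
object ConvertibleFilterDemo extends App {
  sealed trait Expr
  case class Pred(name: String, refs: Set[String]) extends Expr
  case class And(left: Expr, right: Expr) extends Expr
  case class Or(left: Expr, right: Expr) extends Expr

  def convertibleFilter(cond: Expr, outputSet: Set[String]): Option[Expr] = cond match {
    // And: keep whichever side is convertible, so a conjunction can be
    // partially pushed down even when one conjunct is not convertible.
    case And(l, r) =>
      (convertibleFilter(l, outputSet), convertibleFilter(r, outputSet)) match {
        case (Some(lc), Some(rc)) => Some(And(lc, rc))
        case (Some(lc), None) => Some(lc)
        case (None, Some(rc)) => Some(rc)
        case _ => None
      }
    // Or: both children must be at least partially convertible, mirroring
    // the (a1 AND a2) OR (b1 AND b2) example in the diff above.
    case Or(l, r) =>
      for {
        lc <- convertibleFilter(l, outputSet)
        rc <- convertibleFilter(r, outputSet)
      } yield Or(lc, rc)
    // Leaf: convertible iff all of its references are within outputSet.
    case p @ Pred(_, refs) =>
      if (refs.subsetOf(outputSet)) Some(p) else None
  }

  // (a1 AND a2) OR (b1 AND b2), where only a1 and b1 reference the
  // pushdown target's output "l":
  val cond = Or(
    And(Pred("a1", Set("l")), Pred("a2", Set("l", "r"))),
    And(Pred("b1", Set("l")), Pred("b2", Set("l", "r"))))

  // Prints Some(Or(Pred(a1,...), Pred(b1,...))): (a1 OR b1) is pushable.
  println(convertibleFilter(cond, Set("l")))
}
```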

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 4 additions & 4 deletions
@@ -51,8 +51,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val excludedOnceBatches: Set[String] =
     Set(
       "PartitionPruning",
-      "Extract Python UDFs",
-      "Push CNF predicate through join")
+      "Extract Python UDFs")
 
   protected def fixedPoint =
     FixedPoint(
@@ -123,8 +122,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
       rulesWithoutInferFiltersFromConstraints: _*) ::
     // Set strategy to Once to avoid pushing filter every time because we do not change the
     // join condition.
-    Batch("Push CNF predicate through join", Once,
-      PushCNFPredicateThroughJoin) :: Nil
+    Batch("Push extra predicate through join", fixedPoint,
+      PushExtraPredicateThroughJoin,
+      PushDownPredicates) :: Nil
   }
 
   val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
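The switch from `Once` to `fixedPoint` lets the two rules feed each other: after `PushExtraPredicateThroughJoin` moves a convertible predicate into a join child, `PushDownPredicates` may push it further, and the batch reruns until the plan stops changing. A minimal sketch of that fixed-point idea, assuming a toy `Plan` type and plain functions rather than Catalyst's `RuleExecutor`/`Batch` classes:

```scala
// A sketch of a fixed-point batch under assumed toy types.
object FixedPointBatchDemo extends App {
  type Plan = Set[String] // hypothetical stand-in for a logical plan

  // Reapply all rules, in order, until the plan stops changing or the
  // iteration cap is hit -- the essence of a fixed-point strategy.
  def runBatch(rules: Seq[Plan => Plan], maxIterations: Int)(plan: Plan): Plan = {
    var current = plan
    var iteration = 0
    var changed = true
    while (changed && iteration < maxIterations) {
      val next = rules.foldLeft(current)((p, rule) => rule(p))
      changed = next != current
      current = next
      iteration += 1
    }
    current
  }

  // Two toy rules that chain: the first rewrites "extra" into
  // "pushedToJoin", the second rewrites "pushedToJoin" into "pushedToScan".
  val pushThroughJoin: Plan => Plan =
    p => if (p("extra")) p - "extra" + "pushedToJoin" else p
  val pushDown: Plan => Plan =
    p => if (p("pushedToJoin")) p - "pushedToJoin" + "pushedToScan" else p

  // The loop stops only when a full pass leaves the plan unchanged, so a
  // rule that re-exposes work for an earlier rule still gets picked up.
  println(runBatch(Seq(pushThroughJoin, pushDown), maxIterations = 100)(Set("extra")))
  // -> Set(pushedToScan)
}
```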

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala

Lines changed: 0 additions & 68 deletions
This file was deleted.

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 0 additions & 15 deletions
@@ -545,19 +545,6 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
-  val MAX_CNF_NODE_COUNT =
-    buildConf("spark.sql.optimizer.maxCNFNodeCount")
-      .internal()
-      .doc("Specifies the maximum allowable number of conjuncts in the result of CNF " +
-        "conversion. If the conversion exceeds the threshold, an empty sequence is returned. " +
-        "For example, CNF conversion of (a && b) || (c && d) generates " +
-        "four conjuncts (a || c) && (a || d) && (b || c) && (b || d).")
-      .version("3.1.0")
-      .intConf
-      .checkValue(_ >= 0,
-        "The depth of the maximum rewriting conjunction normal form must be positive.")
-      .createWithDefault(128)
-
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -2948,8 +2935,6 @@ class SQLConf extends Serializable with Logging {
 
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
-  def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT)
-
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
   def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
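The removed config existed only to cap CNF blow-up, which the new convertible-filter approach avoids entirely. To see the growth the guard protected against, here is a small sketch that counts CNF conjuncts for a toy Or/And tree; the types and helper names are hypothetical, not Catalyst's:

```scala
// A sketch of why CNF conversion needed a node-count limit: an Or over two
// CNF results with m and n conjuncts yields m * n conjuncts, so nested
// Or-of-And conditions grow multiplicatively.
object CnfGrowthDemo extends App {
  sealed trait Expr
  case object Leaf extends Expr
  case class And(l: Expr, r: Expr) extends Expr
  case class Or(l: Expr, r: Expr) extends Expr

  // Number of conjuncts in the CNF of e.
  def cnfConjuncts(e: Expr): Long = e match {
    case Leaf => 1L
    case And(l, r) => cnfConjuncts(l) + cnfConjuncts(r) // conjunctions concatenate
    case Or(l, r) => cnfConjuncts(l) * cnfConjuncts(r)  // cross product of conjuncts
  }

  // (a && b) || (c && d): 2 * 2 = 4 conjuncts, matching the example in the
  // removed doc string.
  println(cnfConjuncts(Or(And(Leaf, Leaf), And(Leaf, Leaf))))

  // A balanced Or-of-And tree of depth d squares the count at each level.
  def balanced(depth: Int): Expr =
    if (depth == 0) And(Leaf, Leaf) else Or(balanced(depth - 1), balanced(depth - 1))
  println((1 to 5).map(d => cnfConjuncts(balanced(d)))) // 4, 16, 256, 65536, ...
}
```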
