Skip to content

Commit 2661c3d

Browse files
committed
more reuse cases
1 parent 04536c9 commit 2661c3d

File tree

2 files changed

+72
-26
lines changed

2 files changed

+72
-26
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@
1717

1818
package org.apache.spark.sql.execution.exchange
1919

20+
import java.util.Objects
21+
2022
import scala.collection.mutable
21-
import scala.collection.mutable.ArrayBuffer
2223

2324
import org.apache.spark.broadcast
2425
import org.apache.spark.rdd.RDD
2526
import org.apache.spark.sql.catalyst.InternalRow
2627
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
2728
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
2829
import org.apache.spark.sql.catalyst.rules.Rule
30+
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
2931
import org.apache.spark.sql.execution._
3032
import org.apache.spark.sql.internal.SQLConf
3133
import org.apache.spark.sql.types.StructType
@@ -52,6 +54,13 @@ abstract class Exchange extends UnaryExecNode {
5254
case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange)
5355
extends LeafExecNode {
5456

57+
override def equals(that: Any): Boolean = that match {
58+
case ReusedExchangeExec(output, child) => this.child == output && this.child.eq(child)
59+
case _ => false
60+
}
61+
62+
override def hashCode: Int = Objects.hash(output, child)
63+
5564
override def supportsColumnar: Boolean = child.supportsColumnar
5665

5766
// Ignore this wrapper for canonicalizing.
@@ -113,27 +122,38 @@ case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {
113122
// have the same schema
114123
val exchanges = mutable.Map[StructType, (Exchange, mutable.Map[SparkPlan, Exchange])]()
115124

116-
def reuse(plan: SparkPlan): SparkPlan = plan.transform {
117-
case exchange: Exchange =>
118-
val (firstSameSchemaExchange, sameResultExchanges) =
119-
exchanges.getOrElseUpdate(exchange.schema, (exchange, mutable.Map()))
120-
if (firstSameSchemaExchange.ne(exchange)) {
121-
if (sameResultExchanges.isEmpty) {
122-
sameResultExchanges += firstSameSchemaExchange.canonicalized -> firstSameSchemaExchange
123-
}
124-
val sameResultExchange =
125-
sameResultExchanges.getOrElseUpdate(exchange.canonicalized, exchange)
126-
if (sameResultExchange.ne(exchange)) {
127-
ReusedExchangeExec(exchange.output, sameResultExchange)
125+
def reuse(plan: SparkPlan): SparkPlan = {
126+
// Track exchanges that are replaced to reused exchanges to be able to fix ReusedExchangeExec
127+
// nodes referencing to them
128+
val reuseExchanges = mutable.Map[TreeNodeRef, Exchange]()
129+
130+
plan.transform {
131+
case exchange: Exchange =>
132+
val (firstSameSchemaExchange, sameResultExchanges) =
133+
exchanges.getOrElseUpdate(exchange.schema, (exchange, mutable.Map()))
134+
if (firstSameSchemaExchange.ne(exchange)) {
135+
if (sameResultExchanges.isEmpty) {
136+
sameResultExchanges +=
137+
firstSameSchemaExchange.canonicalized -> firstSameSchemaExchange
138+
}
139+
val sameResultExchange =
140+
sameResultExchanges.getOrElseUpdate(exchange.canonicalized, exchange)
141+
if (sameResultExchange.ne(exchange)) {
142+
reuseExchanges += new TreeNodeRef(exchange) -> sameResultExchange
143+
ReusedExchangeExec(exchange.output, sameResultExchange)
144+
} else {
145+
exchange
146+
}
128147
} else {
129148
exchange
130149
}
131-
} else {
132-
exchange
150+
case reuseExchange @ ReusedExchangeExec(output, child) =>
151+
reuseExchanges.get(new TreeNodeRef(child)).map(ReusedExchangeExec(output, _))
152+
.getOrElse(reuseExchange)
153+
case other => other.transformExpressions {
154+
case sub: ExecSubqueryExpression =>
155+
sub.withNewPlan(reuse(sub.plan).asInstanceOf[BaseSubqueryExec])
133156
}
134-
case other => other.transformExpressions {
135-
case sub: ExecSubqueryExpression =>
136-
sub.withNewPlan(reuse(sub.plan).asInstanceOf[BaseSubqueryExec])
137157
}
138158
}
139159

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,17 +1536,43 @@ class SubquerySuite extends QueryTest with SharedSparkSession {
15361536

15371537
val plan = df.queryExecution.executedPlan
15381538

1539-
val countExchange = plan.collectInPlanAndSubqueries({ case _: Exchange => 1 }).sum
1540-
val countReusedExchange =
1541-
plan.collectInPlanAndSubqueries({ case _: ReusedExchangeExec => 1 }).sum
1539+
val exchangeIds = plan.collectInPlanAndSubqueries { case e: Exchange => e.id }
1540+
val reusedExchangeIds = plan.collectInPlanAndSubqueries {
1541+
case re: ReusedExchangeExec => re.child.id
1542+
}
1543+
1544+
if (reuse) {
1545+
assert(exchangeIds.size == 2, "Exchange reusing not working correctly")
1546+
assert(reusedExchangeIds.size == 3, "Exchange reusing not working correctly")
1547+
assert(reusedExchangeIds.forall(exchangeIds.contains(_)),
1548+
"ReusedExchangeExec should reuse an existing exchange")
1549+
} else {
1550+
assert(exchangeIds.size == 5, "expect 5 Exchange when not reusing")
1551+
assert(reusedExchangeIds.size == 0, "expect 0 ReusedExchangeExec when not reusing")
1552+
}
1553+
1554+
val df2 = sql(
1555+
"""
1556+
SELECT
1557+
(SELECT min(a.key) FROM testData AS a JOIN testData AS b ON b.key = a.key),
1558+
(SELECT max(a.key) FROM testData AS a JOIN testData AS b ON b.key = a.key)
1559+
""".stripMargin)
1560+
1561+
val plan2 = df2.queryExecution.executedPlan
1562+
1563+
val exchangeIds2 = plan2.collectInPlanAndSubqueries { case e: Exchange => e.id }
1564+
val reusedExchangeIds2 = plan2.collectInPlanAndSubqueries {
1565+
case re: ReusedExchangeExec => re.child.id
1566+
}
15421567

15431568
if (reuse) {
1544-
assert(countExchange == 2, "Exchange reusing not working correctly")
1545-
assert(countReusedExchange == 3, "Exchange reusing not working correctly")
1569+
assert(exchangeIds2.size == 3, "Exchange reusing not working correctly")
1570+
assert(reusedExchangeIds2.size == 3, "Exchange reusing not working correctly")
1571+
assert(reusedExchangeIds2.forall(exchangeIds2.contains(_)),
1572+
"ReusedExchangeExec should reuse an existing exchange")
15461573
} else {
1547-
assert(countExchange == 5, "expect 5 Exchange when not reusing")
1548-
assert(countReusedExchange == 0,
1549-
"expect 0 ReusedExchangeExec when not reusing")
1574+
assert(exchangeIds2.size == 6, "expect 6 Exchange when not reusing")
1575+
assert(reusedExchangeIds2.size == 0, "expect 0 ReusedExchangeExec when not reusing")
15501576
}
15511577
}
15521578
}

0 commit comments

Comments
 (0)