diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index a0ad1c72806f5..318f24b2c84e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -239,7 +239,9 @@ object QueryExecution {
    * are correct, insert whole stage code gen, and try to reduce the work done by reusing exchanges
    * and subqueries.
    */
-  private[execution] def preparations(sparkSession: SparkSession): Seq[Rule[SparkPlan]] =
+  private[execution] def preparations(
+      sparkSession: SparkSession,
+      subQuery: Boolean = false): Seq[Rule[SparkPlan]] =
     Seq(
       // `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be no-op
       // as the original plan is hidden behind `AdaptiveSparkPlanExec`.
@@ -249,10 +251,9 @@ object QueryExecution {
       EnsureRequirements(sparkSession.sessionState.conf),
       ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
         sparkSession.sessionState.columnarRules),
-      CollapseCodegenStages(sparkSession.sessionState.conf),
-      ReuseExchange(sparkSession.sessionState.conf),
+      CollapseCodegenStages(sparkSession.sessionState.conf)) ++
+      (if (subQuery) Nil else Seq(ReuseExchange(sparkSession.sessionState.conf))) :+
       ReuseSubquery(sparkSession.sessionState.conf)
-    )
 
   /**
    * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
@@ -283,7 +284,7 @@ object QueryExecution {
    * Prepare the [[SparkPlan]] for execution.
    */
   def prepareExecutedPlan(spark: SparkSession, plan: SparkPlan): SparkPlan = {
-    prepareForExecution(preparations(spark), plan)
+    prepareForExecution(preparations(spark, true), plan)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index b79d3a278bb3e..1e3c916a07a06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -93,11 +93,13 @@ case class HashAggregateExec(
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
-    sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
-      case null | "" => None
-      case fallbackStartsAt =>
-        val splits = fallbackStartsAt.split(",").map(_.trim)
-        Some((splits.head.toInt, splits.last.toInt))
+    Option(sqlContext).flatMap {
+      _.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
+        case null | "" => None
+        case fallbackStartsAt =>
+          val splits = fallbackStartsAt.split(",").map(_.trim)
+          Some((splits.head.toInt, splits.last.toInt))
+      }
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index a1dde415d6e8b..337f303c9349d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.exchange
 
 import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
@@ -107,35 +106,39 @@ case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {
     if (!conf.exchangeReuseEnabled) {
       return plan
     }
-    // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
-    val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
-
-    // Replace a Exchange duplicate with a ReusedExchange
-    def reuse: PartialFunction[Exchange, SparkPlan] = {
-      case exchange: Exchange =>
-        val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
-        val samePlan = sameSchema.find { e =>
-          exchange.sameResult(e)
-        }
-        if (samePlan.isDefined) {
-          // Keep the output of this exchange, the following plans require that to resolve
-          // attributes.
-          ReusedExchangeExec(exchange.output, samePlan.get)
-        } else {
-          sameSchema += exchange
-          exchange
+    // To avoid costly canonicalization of an exchange:
+    // - we use its schema first to check if it can be replaced to a reused exchange at all
+    // - we insert an exchange into the map of canonicalized plans only when at least 2 exchange
+    //   have the same schema
+    val exchanges = mutable.Map[StructType, (Exchange, mutable.Map[SparkPlan, Exchange])]()
+
+    def reuse(plan: SparkPlan): SparkPlan = {
+      plan.transformUp {
+        case exchange: Exchange =>
+          val (firstSameSchemaExchange, sameResultExchanges) =
+            exchanges.getOrElseUpdate(exchange.schema, (exchange, mutable.Map()))
+          if (firstSameSchemaExchange.ne(exchange)) {
+            if (sameResultExchanges.isEmpty) {
+              sameResultExchanges +=
+                firstSameSchemaExchange.canonicalized -> firstSameSchemaExchange
+            }
+            val sameResultExchange =
+              sameResultExchanges.getOrElseUpdate(exchange.canonicalized, exchange)
+            if (sameResultExchange.ne(exchange)) {
+              ReusedExchangeExec(exchange.output, sameResultExchange)
+            } else {
+              exchange
+            }
+          } else {
+            exchange
+          }
+        case other => other.transformExpressions {
+          case sub: ExecSubqueryExpression =>
+            sub.withNewPlan(reuse(sub.plan).asInstanceOf[BaseSubqueryExec])
         }
+      }
     }
 
-    plan transformUp {
-      case exchange: Exchange => reuse(exchange)
-    } transformAllExpressions {
-      // Lookup inside subqueries for duplicate exchanges
-      case in: InSubqueryExec =>
-        val newIn = in.plan.transformUp {
-          case exchange: Exchange => reuse(exchange)
-        }
-        in.copy(plan = newIn.asInstanceOf[BaseSubqueryExec])
-    }
+    reuse(plan)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 4281f01e2756a..6b9a2f5b31495 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -61,7 +61,7 @@ case class ShuffleExchangeExec(
 
   override def nodeName: String = "Exchange"
 
-  private val serializer: Serializer =
+  private lazy val serializer: Serializer =
     new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
 
   @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index c2270c57eb941..556272f184ad2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -104,6 +104,10 @@ case class ScalarSubquery(
     require(updated, s"$this has not finished")
     Literal.create(result, dataType).doGenCode(ctx, ev)
   }
+
+  override lazy val canonicalized: ScalarSubquery = {
+    copy(plan = plan.canonicalized.asInstanceOf[BaseSubqueryExec], exprId = ExprId(0))
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 5020c1047f8dd..e61492bdae25c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
 import org.apache.spark.sql.execution.{ColumnarToRowExec,
   ExecSubqueryExpression, FileSourceScanExec, InputAdapter, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.datasources.FileScanRDD
+import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -1389,6 +1390,62 @@ class SubquerySuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("Exchange reuse across all subquery levels") {
+    Seq(true, false).foreach { reuse =>
+      withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> reuse.toString) {
+        val df = sql(
+          """
+            |SELECT
+            |  (SELECT max(a.key) FROM testData AS a JOIN testData AS b ON b.key = a.key),
+            |  a.key
+            |FROM testData AS a
+            |JOIN testData AS b ON b.key = a.key
+          """.stripMargin)
+
+        val plan = df.queryExecution.executedPlan
+
+        val exchangeIds = plan.collectInPlanAndSubqueries { case e: Exchange => e.id }
+        val reusedExchangeIds = plan.collectInPlanAndSubqueries {
+          case re: ReusedExchangeExec => re.child.id
+        }
+
+        if (reuse) {
+          assert(exchangeIds.size == 2, "Exchange reusing not working correctly")
+          assert(reusedExchangeIds.size == 3, "Exchange reusing not working correctly")
+          assert(reusedExchangeIds.forall(exchangeIds.contains(_)),
+            "ReusedExchangeExec should reuse an existing exchange")
+        } else {
+          assert(exchangeIds.size == 5, "expect 5 Exchange when not reusing")
+          assert(reusedExchangeIds.size == 0, "expect 0 ReusedExchangeExec when not reusing")
+        }
+
+        val df2 = sql(
+          """
+            SELECT
+              (SELECT min(a.key) FROM testData AS a JOIN testData AS b ON b.key = a.key),
+              (SELECT max(a.key) FROM testData AS a JOIN testData2 AS b ON b.a = a.key)
+          """.stripMargin)
+
+        val plan2 = df2.queryExecution.executedPlan
+
+        val exchangeIds2 = plan2.collectInPlanAndSubqueries { case e: Exchange => e.id }
+        val reusedExchangeIds2 = plan2.collectInPlanAndSubqueries {
+          case re: ReusedExchangeExec => re.child.id
+        }
+
+        if (reuse) {
+          assert(exchangeIds2.size == 4, "Exchange reusing not working correctly")
+          assert(reusedExchangeIds2.size == 2, "Exchange reusing not working correctly")
+          assert(reusedExchangeIds2.forall(exchangeIds2.contains(_)),
+            "ReusedExchangeExec should reuse an existing exchange")
+        } else {
+          assert(exchangeIds2.size == 6, "expect 6 Exchange when not reusing")
+          assert(reusedExchangeIds2.size == 0, "expect 0 ReusedExchangeExec when not reusing")
+        }
+      }
+    }
+  }
+
   test("Scalar subquery name should start with scalar-subquery#") {
     val df = sql("SELECT a FROM l WHERE a = (SELECT max(c) FROM r WHERE c = 1)".stripMargin)
     var subqueryExecs: ArrayBuffer[SubqueryExec] = ArrayBuffer.empty
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3dea0b1ce937c..f01825967305f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -470,7 +470,7 @@ class PlannerSuite extends SharedSparkSession {
       Inner,
       None,
       shuffle,
-      shuffle)
+      shuffle.copy())
     val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan)
     if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) {