
Commit 0732e44

yhuang-db authored and peter-toth committed
[SPARK-53809][SQL] Add canonicalization for DataSourceV2ScanRelation
### What changes were proposed in this pull request?

This PR proposes to add a `doCanonicalize` function to DataSourceV2ScanRelation. The implementation is similar to [the one in BatchScanExec](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala#L150), as well as [the one in LogicalRelation](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala#L52).

### Why are the changes needed?

Query optimization rules such as MergeScalarSubqueries check whether two plans are identical by [comparing their canonicalized forms](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala#L219). For DSv2, on the physical-plan side the canonicalization descends the child hierarchy to BatchScanExec, which [has a doCanonicalize function](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala#L150); on the logical-plan side it descends to DataSourceV2ScanRelation, which, however, has no doCanonicalize function. As a result, two logical plans that are semantically identical are not recognized as such. Moreover, for reference, the [DSv1 LogicalRelation](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala#L52) also has `doCanonicalize()`.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

A new unit test is added to show that `MergeScalarSubqueries` works for DataSourceV2ScanRelation. For the query

```sql
select (select max(i) from df) as max_i, (select min(i) from df) as min_i
```

before introducing the canonicalization, the plan is

```
== Parsed Logical Plan ==
'Project [scalar-subquery#2 [] AS max_i#3, scalar-subquery#4 [] AS min_i#5]
: :- 'Project [unresolvedalias('max('i))]
: : +- 'UnresolvedRelation [df], [], false
: +- 'Project [unresolvedalias('min('i))]
: +- 'UnresolvedRelation [df], [], false
+- OneRowRelation

== Analyzed Logical Plan ==
max_i: int, min_i: int
Project [scalar-subquery#2 [] AS max_i#3, scalar-subquery#4 [] AS min_i#5]
: :- Aggregate [max(i#0) AS max(i)#7]
: : +- SubqueryAlias df
: : +- View (`df`, [i#0, j#1])
: : +- RelationV2[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5
: +- Aggregate [min(i#10) AS min(i)#9]
: +- SubqueryAlias df
: +- View (`df`, [i#10, j#11])
: +- RelationV2[i#10, j#11] class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5
+- OneRowRelation

== Optimized Logical Plan ==
Project [scalar-subquery#2 [] AS max_i#3, scalar-subquery#4 [] AS min_i#5]
: :- Aggregate [max(i#0) AS max(i)#7]
: : +- Project [i#0]
: : +- RelationV2[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5
: +- Aggregate [min(i#10) AS min(i)#9]
: +- Project [i#10]
: +- RelationV2[i#10, j#11] class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5
+- OneRowRelation

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
ResultQueryStage 0
+- *(1) Project [Subquery subquery#2, [id=#32] AS max_i#3, Subquery subquery#4, [id=#33] AS min_i#5]
: :- Subquery subquery#2, [id=#32]
: : +- AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
ResultQueryStage 1
+- *(2) HashAggregate(keys=[], functions=[max(i#0)], output=[max(i)#7])
+- ShuffleQueryStage 0
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=58]
+- *(1) HashAggregate(keys=[], functions=[partial_max(i#0)], output=[max#14])
+- *(1) Project [i#0]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
+- == Initial Plan ==
HashAggregate(keys=[], functions=[max(i#0)], output=[max(i)#7])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=19]
+- HashAggregate(keys=[], functions=[partial_max(i#0)], output=[max#14])
+- Project [i#0]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
: +- Subquery subquery#4, [id=#33]
: +- AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
ResultQueryStage 1
+- *(2) HashAggregate(keys=[], functions=[min(i#10)], output=[min(i)#9])
+- ShuffleQueryStage 0
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=63]
+- *(1) HashAggregate(keys=[], functions=[partial_min(i#10)], output=[min#15])
+- *(1) Project [i#10]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#10, j#11] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
+- == Initial Plan ==
HashAggregate(keys=[], functions=[min(i#10)], output=[min(i)#9])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=30]
+- HashAggregate(keys=[], functions=[partial_min(i#10)], output=[min#15])
+- Project [i#10]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#10, j#11] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
+- *(1) Scan OneRowRelation[]
+- == Initial Plan ==
Project [Subquery subquery#2, [id=#32] AS max_i#3, Subquery subquery#4, [id=#33] AS min_i#5]
: :- Subquery subquery#2, [id=#32]
: : +- AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
ResultQueryStage 1
+- *(2) HashAggregate(keys=[], functions=[max(i#0)], output=[max(i)#7])
+- ShuffleQueryStage 0
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=58]
+- *(1) HashAggregate(keys=[], functions=[partial_max(i#0)], output=[max#14])
+- *(1) Project [i#0]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
+- == Initial Plan ==
HashAggregate(keys=[], functions=[max(i#0)], output=[max(i)#7])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=19]
+- HashAggregate(keys=[], functions=[partial_max(i#0)], output=[max#14])
+- Project [i#0]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
: +- Subquery subquery#4, [id=#33]
: +- AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
ResultQueryStage 1
+- *(2) HashAggregate(keys=[], functions=[min(i#10)], output=[min(i)#9])
+- ShuffleQueryStage 0
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=63]
+- *(1) HashAggregate(keys=[], functions=[partial_min(i#10)], output=[min#15])
+- *(1) Project [i#10]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#10, j#11] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
+- == Initial Plan ==
HashAggregate(keys=[], functions=[min(i#10)], output=[min(i)#9])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=30]
+- HashAggregate(keys=[], functions=[partial_min(i#10)], output=[min#15])
+- Project [i#10]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#10, j#11] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
+- Scan OneRowRelation[]
```

After introducing the canonicalization, the plan is as follows; note the **ReusedSubquery** node:

```
== Parsed Logical Plan ==
'Project [scalar-subquery#2 [] AS max_i#3, scalar-subquery#4 [] AS min_i#5]
: :- 'Project [unresolvedalias('max('i))]
: : +- 'UnresolvedRelation [df], [], false
: +- 'Project [unresolvedalias('min('i))]
: +- 'UnresolvedRelation [df], [], false
+- OneRowRelation

== Analyzed Logical Plan ==
max_i: int, min_i: int
Project [scalar-subquery#2 [] AS max_i#3, scalar-subquery#4 [] AS min_i#5]
: :- Aggregate [max(i#0) AS max(i)#7]
: : +- SubqueryAlias df
: : +- View (`df`, [i#0, j#1])
: : +- RelationV2[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5
: +- Aggregate [min(i#10) AS min(i)#9]
: +- SubqueryAlias df
: +- View (`df`, [i#10, j#11])
: +- RelationV2[i#10, j#11] class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5
+- OneRowRelation

== Optimized Logical Plan ==
Project [scalar-subquery#2 [].max(i) AS max_i#3, scalar-subquery#4 [].min(i) AS min_i#5]
: :- Project [named_struct(max(i), max(i)#7, min(i), min(i)#9) AS mergedValue#14]
: : +- Aggregate [max(i#0) AS max(i)#7, min(i#0) AS min(i)#9]
: : +- Project [i#0]
: : +- RelationV2[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5
: +- Project [named_struct(max(i), max(i)#7, min(i), min(i)#9) AS mergedValue#14]
: +- Aggregate [max(i#0) AS max(i)#7, min(i#0) AS min(i)#9]
: +- Project [i#0]
: +- RelationV2[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5
+- OneRowRelation

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
ResultQueryStage 0
+- *(1) Project [Subquery subquery#2, [id=#40].max(i) AS max_i#3, ReusedSubquery Subquery subquery#2, [id=#40].min(i) AS min_i#5]
: :- Subquery subquery#2, [id=#40]
: : +- AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
ResultQueryStage 1
+- *(2) Project [named_struct(max(i), max(i)#7, min(i), min(i)#9) AS mergedValue#14]
+- *(2) HashAggregate(keys=[], functions=[max(i#0), min(i#0)], output=[max(i)#7, min(i)#9])
+- ShuffleQueryStage 0
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=71]
+- *(1) HashAggregate(keys=[], functions=[partial_max(i#0), partial_min(i#0)], output=[max#16, min#17])
+- *(1) Project [i#0]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
+- == Initial Plan ==
Project [named_struct(max(i), max(i)#7, min(i), min(i)#9) AS mergedValue#14]
+- HashAggregate(keys=[], functions=[max(i#0), min(i#0)], output=[max(i)#7, min(i)#9])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=22]
+- HashAggregate(keys=[], functions=[partial_max(i#0), partial_min(i#0)], output=[max#16, min#17])
+- Project [i#0]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
: +- ReusedSubquery Subquery subquery#2, [id=#40]
+- *(1) Scan OneRowRelation[]
+- == Initial Plan ==
Project [Subquery subquery#2, [id=#40].max(i) AS max_i#3, Subquery subquery#4, [id=#41].min(i) AS min_i#5]
: :- Subquery subquery#2, [id=#40]
: : +- AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
ResultQueryStage 1
+- *(2) Project [named_struct(max(i), max(i)#7, min(i), min(i)#9) AS mergedValue#14]
+- *(2) HashAggregate(keys=[], functions=[max(i#0), min(i#0)], output=[max(i)#7, min(i)#9])
+- ShuffleQueryStage 0
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=71]
+- *(1) HashAggregate(keys=[], functions=[partial_max(i#0), partial_min(i#0)], output=[max#16, min#17])
+- *(1) Project [i#0]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
+- == Initial Plan ==
Project [named_struct(max(i), max(i)#7, min(i), min(i)#9) AS mergedValue#14]
+- HashAggregate(keys=[], functions=[max(i#0), min(i#0)], output=[max(i)#7, min(i)#9])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=22]
+- HashAggregate(keys=[], functions=[partial_max(i#0), partial_min(i#0)], output=[max#16, min#17])
+- Project [i#0]
+- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
: +- Subquery subquery#4, [id=#41]
: +- AdaptiveSparkPlan isFinalPlan=false
: +- Project [named_struct(max(i), max(i)#7, min(i), min(i)#9) AS mergedValue#14]
: +- HashAggregate(keys=[], functions=[max(i#0), min(i#0)], output=[max(i)#7, min(i)#9])
: +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=37]
: +- HashAggregate(keys=[], functions=[partial_max(i#0), partial_min(i#0)], output=[max#16, min#17])
: +- Project [i#0]
: +- BatchScan class org.apache.spark.sql.connector.SimpleDataSourceV2$$anon$5[i#0, j#1] class org.apache.spark.sql.connector.SimpleDataSourceV2$MyScanBuilder RuntimeFilters: []
+- Scan OneRowRelation[]
```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #52529 from yhuang-db/scan-canonicalization.

Authored-by: yhuang-db <[email protected]>
Signed-off-by: Peter Toth <[email protected]>
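For illustration, a minimal sketch (not part of this commit) of the equality check that the canonicalization enables, mirroring the new unit test's setup with `SimpleDataSourceV2`:

```scala
// Hedged sketch: why canonicalization matters for plan-equality checks such as
// the one used by MergeScalarSubqueries. Assumes a SparkSession `spark` and the
// test-only SimpleDataSourceV2 connector.
val df1 = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
val df2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load()

// The optimized plans are not equal: each analysis assigns fresh expression IDs
// (e.g. i#0 vs i#10) to the DataSourceV2ScanRelation output.
assert(df1.queryExecution.optimizedPlan != df2.queryExecution.optimizedPlan)

// With doCanonicalize in place, the canonicalized forms compare equal, so the
// two scans are recognized as semantically identical.
assert(df1.queryExecution.optimizedPlan.canonicalized ==
  df2.queryExecution.optimizedPlan.canonicalized)
```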
1 parent 928f253 commit 0732e44

File tree

2 files changed: +98 −0 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation, TimeTravelSpec}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils}
@@ -175,6 +176,15 @@ case class DataSourceV2ScanRelation(
       Statistics(sizeInBytes = conf.defaultSizeInBytes)
     }
   }
+
+  override def doCanonicalize(): DataSourceV2ScanRelation = {
+    this.copy(
+      relation = this.relation.copy(
+        output = this.relation.output.map(QueryPlan.normalizeExpressions(_, this.relation.output))
+      ),
+      output = this.output.map(QueryPlan.normalizeExpressions(_, this.output))
+    )
+  }
 }
 
 /**
```
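The new `doCanonicalize` leans on `QueryPlan.normalizeExpressions`, which rewrites an attribute's expression ID to its ordinal position in the supplied attribute list. A rough, self-contained sketch of that effect (the attribute names and IDs below are illustrative, not taken from this commit):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.IntegerType

// Two outputs over the same schema but with different expression IDs, as two
// independent analyses of the same table would produce.
val out1 = Seq(
  AttributeReference("i", IntegerType)(ExprId(0)),
  AttributeReference("j", IntegerType)(ExprId(1)))
val out2 = Seq(
  AttributeReference("i", IntegerType)(ExprId(10)),
  AttributeReference("j", IntegerType)(ExprId(11)))

// normalizeExpressions replaces each exprId with the attribute's position in
// the given list, so both outputs normalize to the same attributes.
val norm1 = out1.map(QueryPlan.normalizeExpressions(_, out1))
val norm2 = out2.map(QueryPlan.normalizeExpressions(_, out2))
assert(norm1.map(_.exprId) == norm2.map(_.exprId))
```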

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala

Lines changed: 88 additions & 0 deletions
```diff
@@ -26,6 +26,8 @@ import test.org.apache.spark.sql.connector._
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
 import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
@@ -36,6 +38,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning,
 import org.apache.spark.sql.execution.SortExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
 import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.expressions.Window
@@ -976,6 +979,79 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
       assert(result.length == 1)
     }
   }
+
+  test("SPARK-53809: scan canonicalization") {
+    val table = new SimpleDataSourceV2().getTable(CaseInsensitiveStringMap.empty())
+
+    def createDsv2ScanRelation(): DataSourceV2ScanRelation = {
+      val relation = DataSourceV2Relation.create(
+        table, None, None, CaseInsensitiveStringMap.empty())
+      val scan = relation.table.asReadable.newScanBuilder(relation.options).build()
+      DataSourceV2ScanRelation(relation, scan, relation.output)
+    }
+
+    // Create two DataSourceV2ScanRelation instances, representing the scan of the same table
+    val scanRelation1 = createDsv2ScanRelation()
+    val scanRelation2 = createDsv2ScanRelation()
+
+    // the two instances should not be the same, as they should have different attribute IDs
+    assert(scanRelation1 != scanRelation2,
+      "Two created DataSourceV2ScanRelation instances should not be the same")
+    assert(scanRelation1.output.map(_.exprId).toSet != scanRelation2.output.map(_.exprId).toSet,
+      "Output attributes should have different expression IDs before canonicalization")
+    assert(scanRelation1.relation.output.map(_.exprId).toSet !=
+      scanRelation2.relation.output.map(_.exprId).toSet,
+      "Relation output attributes should have different expression IDs before canonicalization")
+
+    // After canonicalization, the two instances should be equal
+    assert(scanRelation1.canonicalized == scanRelation2.canonicalized,
+      "Canonicalized DataSourceV2ScanRelation instances should be equal")
+  }
+
+  test("SPARK-53809: check mergeScalarSubqueries is effective for DataSourceV2ScanRelation") {
+    val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
+    df.createOrReplaceTempView("df")
+
+    val query = sql("select (select max(i) from df) as max_i, (select min(i) from df) as min_i")
+    val optimizedPlan = query.queryExecution.optimizedPlan
+
+    // check optimizedPlan merged scalar subqueries `select max(i), min(i) from df`
+    val sub1 = optimizedPlan.asInstanceOf[Project].projectList.head.collect {
+      case s: ScalarSubquery => s
+    }
+    val sub2 = optimizedPlan.asInstanceOf[Project].projectList(1).collect {
+      case s: ScalarSubquery => s
+    }
+
+    // Both subqueries should reference the same merged plan `select max(i), min(i) from df`
+    assert(sub1.nonEmpty && sub2.nonEmpty, "Both scalar subqueries should exist")
+    assert(sub1.head.plan == sub2.head.plan,
+      "Both subqueries should reference the same merged plan")
+
+    // Extract the aggregate from the merged plan sub1
+    val agg = sub1.head.plan.collect {
+      case a: Aggregate => a
+    }.head
+
+    // Check that the aggregate contains both max(i) and min(i)
+    val aggFunctionSet = agg.aggregateExpressions.flatMap { expr =>
+      expr.collect {
+        case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression =>
+          ae.aggregateFunction
+      }
+    }.toSet
+
+    assert(aggFunctionSet.size == 2, "Aggregate should contain exactly two aggregate functions")
+    assert(aggFunctionSet
+      .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Max]),
+      "Aggregate should contain max(i)")
+    assert(aggFunctionSet
+      .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Min]),
+      "Aggregate should contain min(i)")
+
+    // Verify the query produces correct results
+    checkAnswer(query, Row(9, 0))
+  }
 }
 
 case class RangeInputPartition(start: Int, end: Int) extends InputPartition
@@ -1081,6 +1157,18 @@ class SimpleDataSourceV2 extends TestingV2Source {
     override def planInputPartitions(): Array[InputPartition] = {
       Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
     }
+
+    override def equals(obj: Any): Boolean = {
+      obj match {
+        case s: Scan =>
+          this.readSchema() == s.readSchema()
+        case _ => false
+      }
+    }
+
+    override def hashCode(): Int = {
+      this.readSchema().hashCode()
+    }
   }
 
   override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
```
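A note on the test-only `equals`/`hashCode` override above: canonicalization normalizes the relation's attributes, but `DataSourceV2ScanRelation` is a case class whose `scan` field is still compared directly, so a connector's `Scan` needs value-based equality for two canonicalized scan relations to compare equal. A minimal, hedged sketch of such a Scan (the class name and the schema are illustrative, not from this commit):

```scala
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.types.{IntegerType, StructType}

// Hedged sketch: a Scan with value-based equality keyed on its read schema,
// mirroring the test change above. A real connector would likely also compare
// pushed filters and any other state that affects what the scan produces.
class SchemaEqualScan extends Scan {
  override def readSchema(): StructType = new StructType().add("i", IntegerType)

  override def equals(obj: Any): Boolean = obj match {
    case s: Scan => readSchema() == s.readSchema()
    case _ => false
  }

  override def hashCode(): Int = readSchema().hashCode()
}
```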
