@@ -26,6 +26,8 @@ import test.org.apache.spark.sql.connector._
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
 import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
@@ -36,6 +38,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning,
 import org.apache.spark.sql.execution.SortExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
 import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.expressions.Window
@@ -976,6 +979,79 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
       assert(result.length == 1)
     }
   }
+
+  test("SPARK-53809: scan canonicalization") {
+    val table = new SimpleDataSourceV2().getTable(CaseInsensitiveStringMap.empty())
+
+    def createDsv2ScanRelation(): DataSourceV2ScanRelation = {
+      val relation = DataSourceV2Relation.create(
+        table, None, None, CaseInsensitiveStringMap.empty())
+      val scan = relation.table.asReadable.newScanBuilder(relation.options).build()
+      DataSourceV2ScanRelation(relation, scan, relation.output)
+    }
+
+    // Create two DataSourceV2ScanRelation instances representing scans of the same table.
+    val scanRelation1 = createDsv2ScanRelation()
+    val scanRelation2 = createDsv2ScanRelation()
+
+    // The two instances should not be the same, as they have different attribute IDs.
+    assert(scanRelation1 != scanRelation2,
+      "Two created DataSourceV2ScanRelation instances should not be the same")
+    assert(scanRelation1.output.map(_.exprId).toSet != scanRelation2.output.map(_.exprId).toSet,
+      "Output attributes should have different expression IDs before canonicalization")
+    assert(scanRelation1.relation.output.map(_.exprId).toSet !=
+      scanRelation2.relation.output.map(_.exprId).toSet,
+      "Relation output attributes should have different expression IDs before canonicalization")
+
+    // After canonicalization, the two instances should be equal.
+    assert(scanRelation1.canonicalized == scanRelation2.canonicalized,
+      "Canonicalized DataSourceV2ScanRelation instances should be equal")
+  }
+
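+  // Note: MergeScalarSubqueries decides mergeability by comparing canonicalized
+  // plans, so the canonicalization behavior verified above is what enables the
+  // merge exercised below.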
+  test("SPARK-53809: check mergeScalarSubqueries is effective for DataSourceV2ScanRelation") {
+    val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
+    df.createOrReplaceTempView("df")
+
+    val query = sql("select (select max(i) from df) as max_i, (select min(i) from df) as min_i")
+    val optimizedPlan = query.queryExecution.optimizedPlan
+
+    // Check that the optimizer merged the two scalar subqueries into a single
+    // `select max(i), min(i) from df`.
+    val sub1 = optimizedPlan.asInstanceOf[Project].projectList.head.collect {
+      case s: ScalarSubquery => s
+    }
+    val sub2 = optimizedPlan.asInstanceOf[Project].projectList(1).collect {
+      case s: ScalarSubquery => s
+    }
+
+    // Both subqueries should reference the same merged plan.
+    assert(sub1.nonEmpty && sub2.nonEmpty, "Both scalar subqueries should exist")
+    assert(sub1.head.plan == sub2.head.plan,
+      "Both subqueries should reference the same merged plan")
+
+    // Extract the aggregate from the merged plan.
+    val agg = sub1.head.plan.collect {
+      case a: Aggregate => a
+    }.head
+
+    // Check that the aggregate contains both max(i) and min(i).
+    val aggFunctionSet = agg.aggregateExpressions.flatMap { expr =>
+      expr.collect {
+        case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression =>
+          ae.aggregateFunction
+      }
+    }.toSet
+
+    assert(aggFunctionSet.size == 2, "Aggregate should contain exactly two aggregate functions")
+    assert(aggFunctionSet
+      .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Max]),
+      "Aggregate should contain max(i)")
+    assert(aggFunctionSet
+      .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Min]),
+      "Aggregate should contain min(i)")
+
+    // Verify the query produces correct results.
+    checkAnswer(query, Row(9, 0))
+  }
 }
 
 case class RangeInputPartition(start: Int, end: Int) extends InputPartition
@@ -1081,6 +1157,18 @@ class SimpleDataSourceV2 extends TestingV2Source {
     override def planInputPartitions(): Array[InputPartition] = {
       Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
     }
+
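+    // Define scan equality in terms of the read schema so that two scans of the
+    // same table compare equal after canonicalization (see the SPARK-53809 tests).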
+    override def equals(obj: Any): Boolean = {
+      obj match {
+        case s: Scan =>
+          this.readSchema() == s.readSchema()
+        case _ => false
+      }
+    }
+
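+    // Keep hashCode consistent with equals by hashing only the read schema.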
+    override def hashCode(): Int = {
+      this.readSchema().hashCode()
+    }
   }
 
   override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {