sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -51,7 +51,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
  override protected val blacklistedOnceBatches: Set[String] =
    Set(
      "PartitionPruning",
      "Extract Python UDFs",
      "Push predicate through join by CNF")

  protected def fixedPoint =
    FixedPoint(
@@ -118,7 +119,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
Batch("Infer Filters", Once,
InferFiltersFromConstraints) ::
Batch("Operator Optimization after Inferring Filters", fixedPoint,
rulesWithoutInferFiltersFromConstraints: _*) :: Nil
rulesWithoutInferFiltersFromConstraints: _*) ::
// Set strategy to Once to avoid pushing filter every time because we do not change the
// join condition.
Batch("Push predicate through join by CNF", Once,
PushPredicateThroughJoinByCNF) :: Nil
}

val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
@@ -1372,6 +1377,96 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Rewrites the join condition into conjunctive normal form (CNF) so that we
 * can push down more predicates.
 */
object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper {
Contributor:
Does it apply to all the predicates? For example, when we push down filters to the data source?

Member Author:
Yes. For example, the scan of the web_sales table in TPC-DS q85.

Before this PR:

+- *(8) Filter (((isnotnull(ws_item_sk#3) AND isnotnull(ws_order_number#17)) AND isnotnull(ws_web_page_sk#12)) AND isnotnull(ws_sold_date_sk#0))
   +- *(8) ColumnarToRow
      +- FileScan parquet default.web_sales[ws_sold_date_sk#0,ws_item_sk#3,ws_web_page_sk#12,ws_order_number#17,ws_quantity#18,ws_sales_price#21,ws_net_profit#33] Batched: true, DataFilters: [isnotnull(ws_item_sk#3), isnotnull(ws_order_number#17), isnotnull(ws_web_page_sk#12), isnotnull(ws_sold_date_sk#0)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark.sql.TPCDSQuerySuite/web_sales], PartitionFilters: [], PushedFilters: [IsNotNull(ws_item_sk), IsNotNull(ws_order_number), IsNotNull(ws_web_page_sk), IsNotNull(ws_sold_date_sk)], ReadSchema: struct<ws_sold_date_sk:int,ws_item_sk:int,ws_web_page_sk:int,ws_order_number:int,ws_quantity:int,ws_sales_price:decimal(7,2),ws_net_profit:decimal(7,2)>

After this PR:

+- *(8) Filter (((((isnotnull(ws_item_sk#3) AND isnotnull(ws_order_number#17)) AND isnotnull(ws_web_page_sk#12)) AND isnotnull(ws_sold_date_sk#0)) AND ((((ws_sales_price#21 >= 100.00) AND (ws_sales_price#21 <= 150.00)) OR ((ws_sales_price#21 >= 50.00) AND (ws_sales_price#21 <= 100.00))) OR ((ws_sales_price#21 >= 150.00) AND (ws_sales_price#21 <= 200.00)))) AND ((((ws_net_profit#33 >= 100.00) AND (ws_net_profit#33 <= 200.00)) OR ((ws_net_profit#33 >= 150.00) AND (ws_net_profit#33 <= 300.00))) OR ((ws_net_profit#33 >= 50.00) AND (ws_net_profit#33 <= 250.00))))
   +- *(8) ColumnarToRow
      +- FileScan parquet default.web_sales[ws_sold_date_sk#0,ws_item_sk#3,ws_web_page_sk#12,ws_order_number#17,ws_quantity#18,ws_sales_price#21,ws_net_profit#33] Batched: true, DataFilters: [isnotnull(ws_item_sk#3), isnotnull(ws_order_number#17), isnotnull(ws_web_page_sk#12), isnotnull(ws_sold_date_sk#0), ((((ws_sales_price#21 >= 100.00) AND (ws_sales_price#21 <= 150.00)) OR ((ws_sales_price#21 >= 50.00) AND (ws_sales_price#21 <= 100.00))) OR ((ws_sales_price#21 >= 150.00) AND (ws_sales_price#21 <= 200.00))), ((((ws_net_profit#33 >= 100.00) AND (ws_net_profit#33 <= 200.00)) OR ((ws_net_profit#33 >= 150.00) AND (ws_net_profit#33 <= 300.00))) OR ((ws_net_profit#33 >= 50.00) AND (ws_net_profit#33 <= 250.00)))], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark.sql.TPCDSQuerySuite/web_sales], PartitionFilters: [], PushedFilters: [IsNotNull(ws_item_sk), IsNotNull(ws_order_number), IsNotNull(ws_web_page_sk), IsNotNull(ws_sold_date_sk), Or(Or(And(GreaterThanOrEqual(ws_sales_price,100.00),LessThanOrEqual(ws_sales_price,150.00)),And(GreaterThanOrEqual(ws_sales_price,50.00),LessThanOrEqual(ws_sales_price,100.00))),And(GreaterThanOrEqual(ws_sales_price,150.00),LessThanOrEqual(ws_sales_price,200.00))), Or(Or(And(GreaterThanOrEqual(ws_net_profit,100.00),LessThanOrEqual(ws_net_profit,200.00)),And(GreaterThanOrEqual(ws_net_profit,150.00),LessThanOrEqual(ws_net_profit,300.00))),And(GreaterThanOrEqual(ws_net_profit,50.00),LessThanOrEqual(ws_net_profit,250.00)))], ReadSchema: struct<ws_sold_date_sk:int,ws_item_sk:int,ws_web_page_sk:int,ws_order_number:int,ws_quantity:int,ws_sales_price:decimal(7,2),ws_net_profit:decimal(7,2)>

/**
 * Rewrite patterns:
 * 1. (a && b) || (c && d) --> (a || c) && (a || d) && (b || c) && (b || d)
 * 2. (a && b) || c --> (a || c) && (b || c)
 * 3. a || (b && c) --> (a || b) && (a || c)
 *
 * To avoid generating too many predicates, we first group the predicates that
 * reference the same table (by qualifier).
 */
  private def toCNF(condition: Expression, depth: Int = 0): Expression = {
    if (depth < SQLConf.get.maxRewritingCNFDepth) {
      condition match {
        case or @ Or(left: And, right: And) =>
          val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
Member Author (@wangyum, May 31, 2020):
We group by qualifier to avoid generating too many predicates. For example, TPC-DS q85 without grouping by qualifier:

== Physical Plan ==
TakeOrderedAndProject(limit=100, orderBy=[substr(r_reason_desc, 1, 20)#137 ASC NULLS FIRST,aggOrder#142 ASC NULLS FIRST,avg(wr_refunded_cash)#139 ASC NULLS FIRST,avg(wr_fee)#140 ASC NULLS FIRST], output=[substr(r_reason_desc, 1, 20)#137,avg(ws_quantity)#138,avg(wr_refunded_cash)#139,avg(wr_fee)#140])
+- *(9) HashAggregate(keys=[r_reason_desc#124], functions=[avg(cast(ws_quantity#18 as bigint)), avg(UnscaledValue(wr_refunded_cash#54)), avg(UnscaledValue(wr_fee#52))])
   +- Exchange hashpartitioning(r_reason_desc#124, 5), true, [id=#351]
      +- *(8) HashAggregate(keys=[r_reason_desc#124], functions=[partial_avg(cast(ws_quantity#18 as bigint)), partial_avg(UnscaledValue(wr_refunded_cash#54)), partial_avg(UnscaledValue(wr_fee#52))])
         +- *(8) Project [ws_quantity#18, wr_fee#52, wr_refunded_cash#54, r_reason_desc#124]
            +- *(8) BroadcastHashJoin [wr_reason_sk#46L], [cast(r_reason_sk#122 as bigint)], Inner, BuildRight
               :- *(8) Project [ws_quantity#18, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :  +- *(8) BroadcastHashJoin [ws_sold_date_sk#0], [d_date_sk#94], Inner, BuildRight
               :     :- *(8) Project [ws_sold_date_sk#0, ws_quantity#18, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :     :  +- *(8) BroadcastHashJoin [wr_refunded_addr_sk#40L], [cast(ca_address_sk#81 as bigint)], Inner, BuildRight, ((((ca_state#89 IN (IN,OH,NJ) AND (ws_net_profit#33 >= 100.00)) AND (ws_net_profit#33 <= 200.00)) OR ((ca_state#89 IN (WI,CT,KY) AND (ws_net_profit#33 >= 150.00)) AND (ws_net_profit#33 <= 300.00))) OR ((ca_state#89 IN (LA,IA,AR) AND (ws_net_profit#33 >= 50.00)) AND (ws_net_profit#33 <= 250.00)))
               :     :     :- *(8) Project [ws_sold_date_sk#0, ws_quantity#18, ws_net_profit#33, wr_refunded_addr_sk#40L, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :     :     :  +- *(8) BroadcastHashJoin [wr_returning_cdemo_sk#42L, cd_marital_status#74, cd_education_status#75], [cast(cd_demo_sk#125 as bigint), cd_marital_status#127, cd_education_status#128], Inner, BuildRight
               :     :     :     :- *(8) Project [ws_sold_date_sk#0, ws_quantity#18, ws_net_profit#33, wr_refunded_addr_sk#40L, wr_returning_cdemo_sk#42L, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54, cd_marital_status#74, cd_education_status#75]
               :     :     :     :  +- *(8) BroadcastHashJoin [wr_refunded_cdemo_sk#38L], [cast(cd_demo_sk#72 as bigint)], Inner, BuildRight, ((((((cd_marital_status#74 = M) AND (cd_education_status#75 = Advanced Degree)) AND (ws_sales_price#21 >= 100.00)) AND (ws_sales_price#21 <= 150.00)) OR ((((cd_marital_status#74 = S) AND (cd_education_status#75 = College)) AND (ws_sales_price#21 >= 50.00)) AND (ws_sales_price#21 <= 100.00))) OR ((((cd_marital_status#74 = W) AND (cd_education_status#75 = 2 yr Degree)) AND (ws_sales_price#21 >= 150.00)) AND (ws_sales_price#21 <= 200.00)))
               :     :     :     :     :- *(8) Project [ws_sold_date_sk#0, ws_quantity#18, ws_sales_price#21, ws_net_profit#33, wr_refunded_cdemo_sk#38L, wr_refunded_addr_sk#40L, wr_returning_cdemo_sk#42L, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :     :     :     :     :  +- *(8) BroadcastHashJoin [ws_web_page_sk#12], [wp_web_page_sk#58], Inner, BuildRight
               :     :     :     :     :     :- *(8) Project [ws_sold_date_sk#0, ws_web_page_sk#12, ws_quantity#18, ws_sales_price#21, ws_net_profit#33, wr_refunded_cdemo_sk#38L, wr_refunded_addr_sk#40L, wr_returning_cdemo_sk#42L, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :     :     :     :     :     :  +- *(8) BroadcastHashJoin [cast(ws_item_sk#3 as bigint), cast(ws_order_number#17 as bigint)], [wr_item_sk#36L, wr_order_number#47L], Inner, BuildRight
               :     :     :     :     :     :     :- *(8) Project [ws_sold_date_sk#0, ws_item_sk#3, ws_web_page_sk#12, ws_order_number#17, ws_quantity#18, ws_sales_price#21, ws_net_profit#33]
               :     :     :     :     :     :     :  +- *(8) Filter (((((((((((((((((((isnotnull(ws_item_sk#3) AND isnotnull(ws_order_number#17)) AND isnotnull(ws_web_page_sk#12)) AND isnotnull(ws_sold_date_sk#0)) AND (((ws_sales_price#21 >= 100.00) OR (ws_sales_price#21 >= 50.00)) OR (ws_sales_price#21 >= 150.00))) AND (((ws_sales_price#21 >= 100.00) OR (ws_sales_price#21 <= 100.00)) OR (ws_sales_price#21 >= 150.00))) AND (((ws_sales_price#21 <= 150.00) OR (ws_sales_price#21 >= 50.00)) OR (ws_sales_price#21 >= 150.00))) AND (((ws_sales_price#21 <= 150.00) OR (ws_sales_price#21 <= 100.00)) OR (ws_sales_price#21 >= 150.00))) AND (((ws_sales_price#21 >= 100.00) OR (ws_sales_price#21 >= 50.00)) OR (ws_sales_price#21 <= 200.00))) AND (((ws_sales_price#21 >= 100.00) OR (ws_sales_price#21 <= 100.00)) OR (ws_sales_price#21 <= 200.00))) AND (((ws_sales_price#21 <= 150.00) OR (ws_sales_price#21 >= 50.00)) OR (ws_sales_price#21 <= 200.00))) AND (((ws_sales_price#21 <= 150.00) OR (ws_sales_price#21 <= 100.00)) OR (ws_sales_price#21 <= 200.00))) AND (((ws_net_profit#33 >= 100.00) OR (ws_net_profit#33 >= 150.00)) OR (ws_net_profit#33 >= 50.00))) AND (((ws_net_profit#33 >= 100.00) OR (ws_net_profit#33 <= 300.00)) OR (ws_net_profit#33 >= 50.00))) AND (((ws_net_profit#33 <= 200.00) OR (ws_net_profit#33 >= 150.00)) OR (ws_net_profit#33 >= 50.00))) AND (((ws_net_profit#33 <= 200.00) OR (ws_net_profit#33 <= 300.00)) OR (ws_net_profit#33 >= 50.00))) AND (((ws_net_profit#33 >= 100.00) OR (ws_net_profit#33 >= 150.00)) OR (ws_net_profit#33 <= 250.00))) AND (((ws_net_profit#33 >= 100.00) OR (ws_net_profit#33 <= 300.00)) OR (ws_net_profit#33 <= 250.00))) AND (((ws_net_profit#33 <= 200.00) OR (ws_net_profit#33 >= 150.00)) OR (ws_net_profit#33 <= 250.00))) AND (((ws_net_profit#33 <= 200.00) OR (ws_net_profit#33 <= 300.00)) OR (ws_net_profit#33 <= 250.00)))
               :     :     :     :     :     :     :     +- *(8) ColumnarToRow
               :     :     :     :     :     :     :        +- FileScan parquet default.web_sales[ws_sold_date_sk#0,ws_item_sk#3,ws_web_page_sk#12,ws_order_number#17,ws_quantity#18,ws_sales_price#21,ws_net_profit#33] Batched: true, DataFilters: [isnotnull(ws_item_sk#3), isnotnull(ws_order_number#17), isnotnull(ws_web_page_sk#12), isnotnull(..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(ws_item_sk), IsNotNull(ws_order_number), IsNotNull(ws_web_page_sk), IsNotNull(ws_sold_..., ReadSchema: struct<ws_sold_date_sk:int,ws_item_sk:int,ws_web_page_sk:int,ws_order_number:int,ws_quantity:int,...
               :     :     :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, true], input[5, bigint, true])), [id=#291]
               :     :     :     :     :     :        +- *(1) Project [wr_item_sk#36L, wr_refunded_cdemo_sk#38L, wr_refunded_addr_sk#40L, wr_returning_cdemo_sk#42L, wr_reason_sk#46L, wr_order_number#47L, wr_fee#52, wr_refunded_cash#54]
               :     :     :     :     :     :           +- *(1) Filter (((((isnotnull(wr_item_sk#36L) AND isnotnull(wr_order_number#47L)) AND isnotnull(wr_refunded_cdemo_sk#38L)) AND isnotnull(wr_returning_cdemo_sk#42L)) AND isnotnull(wr_refunded_addr_sk#40L)) AND isnotnull(wr_reason_sk#46L))
               :     :     :     :     :     :              +- *(1) ColumnarToRow
               :     :     :     :     :     :                 +- FileScan parquet default.web_returns[wr_item_sk#36L,wr_refunded_cdemo_sk#38L,wr_refunded_addr_sk#40L,wr_returning_cdemo_sk#42L,wr_reason_sk#46L,wr_order_number#47L,wr_fee#52,wr_refunded_cash#54] Batched: true, DataFilters: [isnotnull(wr_item_sk#36L), isnotnull(wr_order_number#47L), isnotnull(wr_refunded_cdemo_sk#38L), ..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(wr_item_sk), IsNotNull(wr_order_number), IsNotNull(wr_refunded_cdemo_sk), IsNotNull(wr..., ReadSchema: struct<wr_item_sk:bigint,wr_refunded_cdemo_sk:bigint,wr_refunded_addr_sk:bigint,wr_returning_cdem...
               :     :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#300]
               :     :     :     :     :        +- *(2) Project [wp_web_page_sk#58]
               :     :     :     :     :           +- *(2) Filter isnotnull(wp_web_page_sk#58)
               :     :     :     :     :              +- *(2) ColumnarToRow
               :     :     :     :     :                 +- FileScan parquet default.web_page[wp_web_page_sk#58] Batched: true, DataFilters: [isnotnull(wp_web_page_sk#58)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(wp_web_page_sk)], ReadSchema: struct<wp_web_page_sk:int>
               :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#309]
               :     :     :     :        +- *(3) Project [cd_demo_sk#72, cd_marital_status#74, cd_education_status#75]
               :     :     :     :           +- *(3) Filter ((((((((((isnotnull(cd_demo_sk#72) AND isnotnull(cd_education_status#75)) AND isnotnull(cd_marital_status#74)) AND (((cd_marital_status#74 = M) OR (cd_marital_status#74 = S)) OR (cd_marital_status#74 = W))) AND (((cd_marital_status#74 = M) OR (cd_marital_status#74 = S)) OR (cd_education_status#75 = 2 yr Degree))) AND (((cd_marital_status#74 = M) OR (cd_education_status#75 = College)) OR (cd_marital_status#74 = W))) AND (((cd_marital_status#74 = M) OR (cd_education_status#75 = College)) OR (cd_education_status#75 = 2 yr Degree))) AND (((cd_education_status#75 = Advanced Degree) OR (cd_marital_status#74 = S)) OR (cd_marital_status#74 = W))) AND (((cd_education_status#75 = Advanced Degree) OR (cd_marital_status#74 = S)) OR (cd_education_status#75 = 2 yr Degree))) AND (((cd_education_status#75 = Advanced Degree) OR (cd_education_status#75 = College)) OR (cd_marital_status#74 = W))) AND (((cd_education_status#75 = Advanced Degree) OR (cd_education_status#75 = College)) OR (cd_education_status#75 = 2 yr Degree)))
               :     :     :     :              +- *(3) ColumnarToRow
               :     :     :     :                 +- FileScan parquet default.customer_demographics[cd_demo_sk#72,cd_marital_status#74,cd_education_status#75] Batched: true, DataFilters: [isnotnull(cd_demo_sk#72), isnotnull(cd_education_status#75), isnotnull(cd_marital_status#74), ((..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(cd_demo_sk), IsNotNull(cd_education_status), IsNotNull(cd_marital_status), Or(Or(Equal..., ReadSchema: struct<cd_demo_sk:int,cd_marital_status:string,cd_education_status:string>
               :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint), input[1, string, true], input[2, string, true])), [id=#318]
               :     :     :        +- *(4) Project [cd_demo_sk#125, cd_marital_status#127, cd_education_status#128]
               :     :     :           +- *(4) Filter ((isnotnull(cd_demo_sk#125) AND isnotnull(cd_education_status#128)) AND isnotnull(cd_marital_status#127))
               :     :     :              +- *(4) ColumnarToRow
               :     :     :                 +- FileScan parquet default.customer_demographics[cd_demo_sk#125,cd_marital_status#127,cd_education_status#128] Batched: true, DataFilters: [isnotnull(cd_demo_sk#125), isnotnull(cd_education_status#128), isnotnull(cd_marital_status#127)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(cd_demo_sk), IsNotNull(cd_education_status), IsNotNull(cd_marital_status)], ReadSchema: struct<cd_demo_sk:int,cd_marital_status:string,cd_education_status:string>
               :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#327]
               :     :        +- *(5) Project [ca_address_sk#81, ca_state#89]
               :     :           +- *(5) Filter (((isnotnull(ca_country#91) AND (ca_country#91 = United States)) AND isnotnull(ca_address_sk#81)) AND ((ca_state#89 IN (IN,OH,NJ) OR ca_state#89 IN (WI,CT,KY)) OR ca_state#89 IN (LA,IA,AR)))
               :     :              +- *(5) ColumnarToRow
               :     :                 +- FileScan parquet default.customer_address[ca_address_sk#81,ca_state#89,ca_country#91] Batched: true, DataFilters: [isnotnull(ca_country#91), (ca_country#91 = United States), isnotnull(ca_address_sk#81), ((ca_sta..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(ca_country), EqualTo(ca_country,United States), IsNotNull(ca_address_sk), Or(Or(In(ca_..., ReadSchema: struct<ca_address_sk:int,ca_state:string,ca_country:string>
               :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#336]
               :        +- *(6) Project [d_date_sk#94]
               :           +- *(6) Filter ((isnotnull(d_year#100) AND (d_year#100 = 2000)) AND isnotnull(d_date_sk#94))
               :              +- *(6) ColumnarToRow
               :                 +- FileScan parquet default.date_dim[d_date_sk#94,d_year#100] Batched: true, DataFilters: [isnotnull(d_year#100), (d_year#100 = 2000), isnotnull(d_date_sk#94)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(d_year), EqualTo(d_year,2000), IsNotNull(d_date_sk)], ReadSchema: struct<d_date_sk:int,d_year:int>
               +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#345]
                  +- *(7) Project [r_reason_sk#122, r_reason_desc#124]
                     +- *(7) Filter isnotnull(r_reason_sk#122)
                        +- *(7) ColumnarToRow
                           +- FileScan parquet default.reason[r_reason_sk#122,r_reason_desc#124] Batched: true, DataFilters: [isnotnull(r_reason_sk#122)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(r_reason_sk)], ReadSchema: struct<r_reason_sk:int,r_reason_desc:string>

With grouping by qualifier:

== Physical Plan ==
TakeOrderedAndProject(limit=100, orderBy=[substr(r_reason_desc, 1, 20)#137 ASC NULLS FIRST,aggOrder#142 ASC NULLS FIRST,avg(wr_refunded_cash)#139 ASC NULLS FIRST,avg(wr_fee)#140 ASC NULLS FIRST], output=[substr(r_reason_desc, 1, 20)#137,avg(ws_quantity)#138,avg(wr_refunded_cash)#139,avg(wr_fee)#140])
+- *(9) HashAggregate(keys=[r_reason_desc#124], functions=[avg(cast(ws_quantity#18 as bigint)), avg(UnscaledValue(wr_refunded_cash#54)), avg(UnscaledValue(wr_fee#52))])
   +- Exchange hashpartitioning(r_reason_desc#124, 5), true, [id=#351]
      +- *(8) HashAggregate(keys=[r_reason_desc#124], functions=[partial_avg(cast(ws_quantity#18 as bigint)), partial_avg(UnscaledValue(wr_refunded_cash#54)), partial_avg(UnscaledValue(wr_fee#52))])
         +- *(8) Project [ws_quantity#18, wr_fee#52, wr_refunded_cash#54, r_reason_desc#124]
            +- *(8) BroadcastHashJoin [wr_reason_sk#46L], [cast(r_reason_sk#122 as bigint)], Inner, BuildRight
               :- *(8) Project [ws_quantity#18, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :  +- *(8) BroadcastHashJoin [ws_sold_date_sk#0], [d_date_sk#94], Inner, BuildRight
               :     :- *(8) Project [ws_sold_date_sk#0, ws_quantity#18, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :     :  +- *(8) BroadcastHashJoin [wr_refunded_addr_sk#40L], [cast(ca_address_sk#81 as bigint)], Inner, BuildRight, ((((ca_state#89 IN (IN,OH,NJ) AND (ws_net_profit#33 >= 100.00)) AND (ws_net_profit#33 <= 200.00)) OR ((ca_state#89 IN (WI,CT,KY) AND (ws_net_profit#33 >= 150.00)) AND (ws_net_profit#33 <= 300.00))) OR ((ca_state#89 IN (LA,IA,AR) AND (ws_net_profit#33 >= 50.00)) AND (ws_net_profit#33 <= 250.00)))
               :     :     :- *(8) Project [ws_sold_date_sk#0, ws_quantity#18, ws_net_profit#33, wr_refunded_addr_sk#40L, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :     :     :  +- *(8) BroadcastHashJoin [wr_returning_cdemo_sk#42L, cd_marital_status#74, cd_education_status#75], [cast(cd_demo_sk#125 as bigint), cd_marital_status#127, cd_education_status#128], Inner, BuildRight
               :     :     :     :- *(8) Project [ws_sold_date_sk#0, ws_quantity#18, ws_net_profit#33, wr_refunded_addr_sk#40L, wr_returning_cdemo_sk#42L, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54, cd_marital_status#74, cd_education_status#75]
               :     :     :     :  +- *(8) BroadcastHashJoin [wr_refunded_cdemo_sk#38L], [cast(cd_demo_sk#72 as bigint)], Inner, BuildRight, ((((((cd_marital_status#74 = M) AND (cd_education_status#75 = Advanced Degree)) AND (ws_sales_price#21 >= 100.00)) AND (ws_sales_price#21 <= 150.00)) OR ((((cd_marital_status#74 = S) AND (cd_education_status#75 = College)) AND (ws_sales_price#21 >= 50.00)) AND (ws_sales_price#21 <= 100.00))) OR ((((cd_marital_status#74 = W) AND (cd_education_status#75 = 2 yr Degree)) AND (ws_sales_price#21 >= 150.00)) AND (ws_sales_price#21 <= 200.00)))
               :     :     :     :     :- *(8) Project [ws_sold_date_sk#0, ws_quantity#18, ws_sales_price#21, ws_net_profit#33, wr_refunded_cdemo_sk#38L, wr_refunded_addr_sk#40L, wr_returning_cdemo_sk#42L, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :     :     :     :     :  +- *(8) BroadcastHashJoin [ws_web_page_sk#12], [wp_web_page_sk#58], Inner, BuildRight
               :     :     :     :     :     :- *(8) Project [ws_sold_date_sk#0, ws_web_page_sk#12, ws_quantity#18, ws_sales_price#21, ws_net_profit#33, wr_refunded_cdemo_sk#38L, wr_refunded_addr_sk#40L, wr_returning_cdemo_sk#42L, wr_reason_sk#46L, wr_fee#52, wr_refunded_cash#54]
               :     :     :     :     :     :  +- *(8) BroadcastHashJoin [cast(ws_item_sk#3 as bigint), cast(ws_order_number#17 as bigint)], [wr_item_sk#36L, wr_order_number#47L], Inner, BuildRight
               :     :     :     :     :     :     :- *(8) Project [ws_sold_date_sk#0, ws_item_sk#3, ws_web_page_sk#12, ws_order_number#17, ws_quantity#18, ws_sales_price#21, ws_net_profit#33]
               :     :     :     :     :     :     :  +- *(8) Filter (((((isnotnull(ws_item_sk#3) AND isnotnull(ws_order_number#17)) AND isnotnull(ws_web_page_sk#12)) AND isnotnull(ws_sold_date_sk#0)) AND ((((ws_sales_price#21 >= 100.00) AND (ws_sales_price#21 <= 150.00)) OR ((ws_sales_price#21 >= 50.00) AND (ws_sales_price#21 <= 100.00))) OR ((ws_sales_price#21 >= 150.00) AND (ws_sales_price#21 <= 200.00)))) AND ((((ws_net_profit#33 >= 100.00) AND (ws_net_profit#33 <= 200.00)) OR ((ws_net_profit#33 >= 150.00) AND (ws_net_profit#33 <= 300.00))) OR ((ws_net_profit#33 >= 50.00) AND (ws_net_profit#33 <= 250.00))))
               :     :     :     :     :     :     :     +- *(8) ColumnarToRow
               :     :     :     :     :     :     :        +- FileScan parquet default.web_sales[ws_sold_date_sk#0,ws_item_sk#3,ws_web_page_sk#12,ws_order_number#17,ws_quantity#18,ws_sales_price#21,ws_net_profit#33] Batched: true, DataFilters: [isnotnull(ws_item_sk#3), isnotnull(ws_order_number#17), isnotnull(ws_web_page_sk#12), isnotnull(..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(ws_item_sk), IsNotNull(ws_order_number), IsNotNull(ws_web_page_sk), IsNotNull(ws_sold_..., ReadSchema: struct<ws_sold_date_sk:int,ws_item_sk:int,ws_web_page_sk:int,ws_order_number:int,ws_quantity:int,...
               :     :     :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, true], input[5, bigint, true])), [id=#291]
               :     :     :     :     :     :        +- *(1) Project [wr_item_sk#36L, wr_refunded_cdemo_sk#38L, wr_refunded_addr_sk#40L, wr_returning_cdemo_sk#42L, wr_reason_sk#46L, wr_order_number#47L, wr_fee#52, wr_refunded_cash#54]
               :     :     :     :     :     :           +- *(1) Filter (((((isnotnull(wr_item_sk#36L) AND isnotnull(wr_order_number#47L)) AND isnotnull(wr_refunded_cdemo_sk#38L)) AND isnotnull(wr_returning_cdemo_sk#42L)) AND isnotnull(wr_refunded_addr_sk#40L)) AND isnotnull(wr_reason_sk#46L))
               :     :     :     :     :     :              +- *(1) ColumnarToRow
               :     :     :     :     :     :                 +- FileScan parquet default.web_returns[wr_item_sk#36L,wr_refunded_cdemo_sk#38L,wr_refunded_addr_sk#40L,wr_returning_cdemo_sk#42L,wr_reason_sk#46L,wr_order_number#47L,wr_fee#52,wr_refunded_cash#54] Batched: true, DataFilters: [isnotnull(wr_item_sk#36L), isnotnull(wr_order_number#47L), isnotnull(wr_refunded_cdemo_sk#38L), ..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(wr_item_sk), IsNotNull(wr_order_number), IsNotNull(wr_refunded_cdemo_sk), IsNotNull(wr..., ReadSchema: struct<wr_item_sk:bigint,wr_refunded_cdemo_sk:bigint,wr_refunded_addr_sk:bigint,wr_returning_cdem...
               :     :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#300]
               :     :     :     :     :        +- *(2) Project [wp_web_page_sk#58]
               :     :     :     :     :           +- *(2) Filter isnotnull(wp_web_page_sk#58)
               :     :     :     :     :              +- *(2) ColumnarToRow
               :     :     :     :     :                 +- FileScan parquet default.web_page[wp_web_page_sk#58] Batched: true, DataFilters: [isnotnull(wp_web_page_sk#58)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(wp_web_page_sk)], ReadSchema: struct<wp_web_page_sk:int>
               :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#309]
               :     :     :     :        +- *(3) Project [cd_demo_sk#72, cd_marital_status#74, cd_education_status#75]
               :     :     :     :           +- *(3) Filter (((isnotnull(cd_demo_sk#72) AND isnotnull(cd_education_status#75)) AND isnotnull(cd_marital_status#74)) AND ((((cd_marital_status#74 = M) AND (cd_education_status#75 = Advanced Degree)) OR ((cd_marital_status#74 = S) AND (cd_education_status#75 = College))) OR ((cd_marital_status#74 = W) AND (cd_education_status#75 = 2 yr Degree))))
               :     :     :     :              +- *(3) ColumnarToRow
               :     :     :     :                 +- FileScan parquet default.customer_demographics[cd_demo_sk#72,cd_marital_status#74,cd_education_status#75] Batched: true, DataFilters: [isnotnull(cd_demo_sk#72), isnotnull(cd_education_status#75), isnotnull(cd_marital_status#74), ((..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(cd_demo_sk), IsNotNull(cd_education_status), IsNotNull(cd_marital_status), Or(Or(And(E..., ReadSchema: struct<cd_demo_sk:int,cd_marital_status:string,cd_education_status:string>
               :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint), input[1, string, true], input[2, string, true])), [id=#318]
               :     :     :        +- *(4) Project [cd_demo_sk#125, cd_marital_status#127, cd_education_status#128]
               :     :     :           +- *(4) Filter ((isnotnull(cd_demo_sk#125) AND isnotnull(cd_education_status#128)) AND isnotnull(cd_marital_status#127))
               :     :     :              +- *(4) ColumnarToRow
               :     :     :                 +- FileScan parquet default.customer_demographics[cd_demo_sk#125,cd_marital_status#127,cd_education_status#128] Batched: true, DataFilters: [isnotnull(cd_demo_sk#125), isnotnull(cd_education_status#128), isnotnull(cd_marital_status#127)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(cd_demo_sk), IsNotNull(cd_education_status), IsNotNull(cd_marital_status)], ReadSchema: struct<cd_demo_sk:int,cd_marital_status:string,cd_education_status:string>
               :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#327]
               :     :        +- *(5) Project [ca_address_sk#81, ca_state#89]
               :     :           +- *(5) Filter (((isnotnull(ca_country#91) AND (ca_country#91 = United States)) AND isnotnull(ca_address_sk#81)) AND ((ca_state#89 IN (IN,OH,NJ) OR ca_state#89 IN (WI,CT,KY)) OR ca_state#89 IN (LA,IA,AR)))
               :     :              +- *(5) ColumnarToRow
               :     :                 +- FileScan parquet default.customer_address[ca_address_sk#81,ca_state#89,ca_country#91] Batched: true, DataFilters: [isnotnull(ca_country#91), (ca_country#91 = United States), isnotnull(ca_address_sk#81), ((ca_sta..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(ca_country), EqualTo(ca_country,United States), IsNotNull(ca_address_sk), Or(Or(In(ca_..., ReadSchema: struct<ca_address_sk:int,ca_state:string,ca_country:string>
               :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#336]
               :        +- *(6) Project [d_date_sk#94]
               :           +- *(6) Filter ((isnotnull(d_year#100) AND (d_year#100 = 2000)) AND isnotnull(d_date_sk#94))
               :              +- *(6) ColumnarToRow
               :                 +- FileScan parquet default.date_dim[d_date_sk#94,d_year#100] Batched: true, DataFilters: [isnotnull(d_year#100), (d_year#100 = 2000), isnotnull(d_date_sk#94)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(d_year), EqualTo(d_year,2000), IsNotNull(d_date_sk)], ReadSchema: struct<d_date_sk:int,d_year:int>
               +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#345]
                  +- *(7) Project [r_reason_sk#122, r_reason_desc#124]
                     +- *(7) Filter isnotnull(r_reason_sk#122)
                        +- *(7) ColumnarToRow
                           +- FileScan parquet default.reason[r_reason_sk#122,r_reason_desc#124] Batched: true, DataFilters: [isnotnull(r_reason_sk#122)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/spark/SPARK-28216/sql/core/spark-warehouse/org.apache.spark..., PartitionFilters: [], PushedFilters: [IsNotNull(r_reason_sk)], ReadSchema: struct<r_reason_sk:int,r_reason_desc:string>
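
As an editorial illustration of the grouping step (self-contained toy Scala; the table tags and predicate strings are made up for the example): conjuncts that reference the same table are AND-combined into a single unit before distribution, so each table contributes one disjunct rather than one per predicate:

// Conjuncts of one Or branch, tagged with their table qualifier.
val conjuncts = Seq(
  ("cd", "cd_marital_status = 'M'"),
  ("cd", "cd_education_status = 'Advanced Degree'"),
  ("ws", "ws_sales_price >= 100.00"),
  ("ws", "ws_sales_price <= 150.00"))

// Group by qualifier and AND each group together, as
// splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier)) does.
val units = conjuncts.groupBy(_._1).values
  .map(_.map(_._2).mkString("(", " AND ", ")"))
// Two units enter the distribution instead of four, which is why the grouped
// plan above stays much smaller than the ungrouped one.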


          val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
          if (lhs.size > 1) {
Contributor:
Shall we pick rhs if it has more conjunctive predicates?

Member Author:
Yes. We will pick it in case or @ Or(left, right: And) => or case or @ Or(left: And, right: And) =>.
E.g., for (a && b) || (c && d), the rewriting steps are:
(a && b) || (c && d) --> (a || (c && d)) && (b || (c && d)) --> (a || c) && (a || d) && (b || c) && (b || d).
We will pick it in case or @ Or(left, right: And) => once a is fixed.
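
An editorial sketch of these distribution steps as standalone Scala (a toy expression ADT, not Catalyst's classes; all names here are illustrative):

sealed trait Expr
case class Leaf(name: String) extends Expr
case class And(l: Expr, r: Expr) extends Expr
case class Or(l: Expr, r: Expr) extends Expr

// Distribute Or over And, mirroring the rule's top-level patterns:
// (a && b) || c --> (a || c) && (b || c), and symmetrically on the right side.
def toCNF(e: Expr): Expr = e match {
  case Or(And(a, b), c) => And(toCNF(Or(a, c)), toCNF(Or(b, c)))
  case Or(a, And(b, c)) => And(toCNF(Or(a, b)), toCNF(Or(a, c)))
  case And(l, r) => And(toCNF(l), toCNF(r))
  case other => other
}

// toCNF((a && b) || (c && d)) evaluates to
// (a || c) && (a || d) && (b || c) && (b || d).
val cnf = toCNF(Or(And(Leaf("a"), Leaf("b")), And(Leaf("c"), Leaf("d"))))

Unlike this toy, the PR's toCNF also groups conjuncts by qualifier first and bounds the recursion with spark.sql.maxRewritingCNFDepth.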

            lhs.values.map(_.reduceLeft(And)).map { e =>
              toCNF(Or(toCNF(e, depth + 1), toCNF(right, depth + 1)), depth + 1)
            }.reduce(And)
          } else if (rhs.size > 1) {
            rhs.values.map(_.reduceLeft(And)).map { e =>
              toCNF(Or(toCNF(left, depth + 1), toCNF(e, depth + 1)), depth + 1)
            }.reduce(And)
          } else {
            or
          }

        case or @ Or(left: And, right) =>
          val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
          if (lhs.size > 1) {
            lhs.values.map(_.reduceLeft(And)).map { e =>
              toCNF(Or(toCNF(e, depth + 1), toCNF(right, depth + 1)), depth + 1)
            }.reduce(And)
          } else {
            or
          }

        case or @ Or(left, right: And) =>
          val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
          if (rhs.size > 1) {
            rhs.values.map(_.reduceLeft(And)).map { e =>
              toCNF(Or(toCNF(left, depth + 1), toCNF(e, depth + 1)), depth + 1)
            }.reduce(And)
          } else {
            or
          }

        case And(left, right) =>
          And(toCNF(left, depth + 1), toCNF(right, depth + 1))

        case other =>
          other
      }
    } else {
      condition
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
      val pushDownCandidates =
        splitConjunctivePredicates(toCNF(joinCondition)).filter(_.deterministic)
      val (leftFilterConditions, rest) =
        pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
      val (rightFilterConditions, _) =
        rest.partition(expr => expr.references.subsetOf(right.outputSet))

      val newLeft = leftFilterConditions.
        reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
      val newRight = rightFilterConditions.
        reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)

      joinType match {
        case _: InnerLike | LeftSemi =>
          Join(newLeft, newRight, joinType, Some(joinCondition), hint)
        case RightOuter =>
          Join(newLeft, right, RightOuter, Some(joinCondition), hint)
        case LeftOuter | LeftAnti | ExistenceJoin(_) =>
          Join(left, newRight, joinType, Some(joinCondition), hint)
        case FullOuter => j
        case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
        case UsingJoin(_, _) => sys.error("Untransformed Using join node")
      }
  }
}
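
For illustration (an editorial trace, not part of the PR; relations x and y as in the tests below), apply derives the per-side filters like this:

// joinCondition: x.b === y.b && ((x.a > 3 && y.a > 13) || (x.a > 1 && y.a > 11))
// toCNF(joinCondition) yields the conjuncts:
//   x.b === y.b            // references both sides: not pushable
//   x.a > 3 || x.a > 1     // references only x: becomes Filter(_, left)
//   x.a > 3 || y.a > 11    // mixed references: dropped from the candidates
//   y.a > 13 || x.a > 1    // mixed references: dropped from the candidates
//   y.a > 13 || y.a > 11   // references only y: becomes Filter(_, right)
// The original joinCondition is kept unchanged; for an inner join both filters
// are installed, for LeftOuter only the right one, for RightOuter only the left.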

/**
* Combines two adjacent [[Limit]] operators into one, merging the
* expressions into one single expression.
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -544,6 +544,18 @@ object SQLConf {
    .booleanConf
    .createWithDefault(true)

  val MAX_REWRITING_CNF_DEPTH =
    buildConf("spark.sql.maxRewritingCNFDepth")
      .internal()
      .doc("The maximum depth for rewriting a join condition into conjunctive normal form. " +
        "A deeper rewrite may find more pushable predicates, but increases the optimization " +
        "time. The default is 10. Setting this value to 0 disables the feature.")
      .version("3.1.0")
      .intConf
      .checkValue(_ >= 0,
        "The maximum depth for rewriting to conjunctive normal form must not be negative.")
      .createWithDefault(10)
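
As a usage sketch (assumes an active SparkSession named spark; the rule reads the value through SQLConf.get at optimization time):

// Disable the CNF rewrite for the current session:
spark.sql("SET spark.sql.maxRewritingCNFDepth=0")
// Restore the default depth:
spark.sql("SET spark.sql.maxRewritingCNFDepth=10")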

  val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
    .internal()
    .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -2845,6 +2857,8 @@ class SQLConf extends Serializable with Logging {

  def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)

  def maxRewritingCNFDepth: Int = getConf(MAX_REWRITING_CNF_DEPTH)

  def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)

  def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -25,12 +25,17 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, IntegerType}
import org.apache.spark.unsafe.types.CalendarInterval

class FilterPushdownSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {

    override protected val blacklistedOnceBatches: Set[String] =
      Set("Push predicate through join by CNF")

    val batches =
      Batch("Subqueries", Once,
        EliminateSubqueryAliases) ::
@@ -39,7 +44,9 @@ class FilterPushdownSuite extends PlanTest {
        PushPredicateThroughNonJoin,
        BooleanSimplification,
        PushPredicateThroughJoin,
        CollapseProject) ::
      Batch("Push predicate through join by CNF", Once,
        PushPredicateThroughJoinByCNF) :: Nil
  }

  val attrA = 'a.int
@@ -1230,4 +1237,154 @@ class FilterPushdownSuite extends PlanTest {

    comparePlans(Optimize.execute(query.analyze), expected)
  }

test("inner join: rewrite filter predicates to conjunctive normal form") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)

val originalQuery = {
x.join(y)
.where(("x.b".attr === "y.b".attr)
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x)
val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
val correctAnswer =
left.join(right, condition = Some("x.b".attr === "y.b".attr
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
.analyze

comparePlans(optimized, correctAnswer)
}

test("inner join: rewrite join predicates to conjunctive normal form") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)

val originalQuery = {
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('a > 3 || 'a > 1).subquery('x)
val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
val correctAnswer =
left.join(right, condition = Some("x.b".attr === "y.b".attr
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
.analyze

comparePlans(optimized, correctAnswer)
Member Author: This matches PostgreSQL's plan:

postgres=# explain select x.* from x join y on (x.b = y.b and ( (x.a > 3 and y.a > 13 ) or (x.a > 1 and y.a > 11) ));
                                QUERY PLAN
---------------------------------------------------------------------------
 Merge Join  (cost=178.36..315.60 rows=3593 width=16)
   Merge Cond: (x.b = y.b)
   Join Filter: (((x.a > 3) AND (y.a > 13)) OR ((x.a > 1) AND (y.a > 11)))
   ->  Sort  (cost=89.18..91.75 rows=1028 width=16)
         Sort Key: x.b
         ->  Seq Scan on x  (cost=0.00..37.75 rows=1028 width=16)
               Filter: ((a > 3) OR (a > 1))
   ->  Sort  (cost=89.18..91.75 rows=1028 width=8)
         Sort Key: y.b
         ->  Seq Scan on y  (cost=0.00..37.75 rows=1028 width=8)
               Filter: ((a > 13) OR (a > 11))
(11 rows)

  }

test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)

val originalQuery = {
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
&& Not(("x.a".attr > 3)
&& ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11))))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('a <= 3 || 'a >= 2).subquery('x)
val right = testRelation.subquery('y)
val correctAnswer =
left.join(right, condition = Some("x.b".attr === "y.b".attr
&& (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13)))
&& (("x.a".attr <= 1) || ("y.a".attr <= 11))))
.analyze
comparePlans(optimized, correctAnswer)
Member Author: This matches PostgreSQL's plan:

postgres=# explain select x.* from x join y on ((x.b = y.b) and Not((x.a > 3) and (x.a < 2 or (y.a > 13)) or (x.a > 1) and (y.a > 11)));
                                          QUERY PLAN
-----------------------------------------------------------------------------------------------
 Merge Join  (cost=218.07..484.71 rows=3874 width=16)
   Merge Cond: (x.b = y.b)
   Join Filter: (((x.a <= 1) OR (y.a <= 11)) AND ((x.a <= 3) OR ((x.a >= 2) AND (y.a <= 13))))
   ->  Sort  (cost=89.18..91.75 rows=1028 width=16)
         Sort Key: x.b
         ->  Seq Scan on x  (cost=0.00..37.75 rows=1028 width=16)
               Filter: ((a <= 3) OR (a >= 2))
   ->  Sort  (cost=128.89..133.52 rows=1850 width=8)
         Sort Key: y.b
         ->  Seq Scan on y  (cost=0.00..28.50 rows=1850 width=8)
(10 rows)

  }

test("left join: rewrite join predicates to conjunctive normal form") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)

val originalQuery = {
x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr)
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.subquery('x)
val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
val correctAnswer =
left.join(right, joinType = LeftOuter, condition = Some("x.b".attr === "y.b".attr
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
.analyze

comparePlans(optimized, correctAnswer)
Member Author: This matches PostgreSQL's plan:

postgres=# explain select x.* from x left join y on (x.b = y.b and ( (x.a > 3 and y.a > 13 ) or (x.a > 1 and y.a > 11) ));
                                QUERY PLAN
---------------------------------------------------------------------------
 Merge Left Join  (cost=218.07..465.05 rows=1996 width=16)
   Merge Cond: (x.b = y.b)
   Join Filter: (((x.a > 3) AND (y.a > 13)) OR ((x.a > 1) AND (y.a > 11)))
   ->  Sort  (cost=128.89..133.52 rows=1850 width=16)
         Sort Key: x.b
         ->  Seq Scan on x  (cost=0.00..28.50 rows=1850 width=16)
   ->  Sort  (cost=89.18..91.75 rows=1028 width=8)
         Sort Key: y.b
         ->  Seq Scan on y  (cost=0.00..37.75 rows=1028 width=8)
               Filter: ((a > 13) OR (a > 11))
(10 rows)

  }

test("right join: rewrite join predicates to conjunctive normal form") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)

val originalQuery = {
x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr)
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('a > 3 || 'a > 1).subquery('x)
val right = testRelation.subquery('y)
val correctAnswer =
left.join(right, joinType = RightOuter, condition = Some("x.b".attr === "y.b".attr
&& (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
.analyze

comparePlans(optimized, correctAnswer)
Member Author: This matches PostgreSQL's plan:

postgres=# explain select x.* from x right join y on (x.b = y.b and ( (x.a > 3 and y.a > 13 ) or (x.a > 1 and y.a > 11) ));
                                QUERY PLAN
---------------------------------------------------------------------------
 Merge Left Join  (cost=218.07..465.05 rows=1996 width=16)
   Merge Cond: (y.b = x.b)
   Join Filter: (((x.a > 3) AND (y.a > 13)) OR ((x.a > 1) AND (y.a > 11)))
   ->  Sort  (cost=128.89..133.52 rows=1850 width=8)
         Sort Key: y.b
         ->  Seq Scan on y  (cost=0.00..28.50 rows=1850 width=8)
   ->  Sort  (cost=89.18..91.75 rows=1028 width=16)
         Sort Key: x.b
         ->  Seq Scan on x  (cost=0.00..37.75 rows=1028 width=16)
               Filter: ((a > 3) OR (a > 1))
(10 rows)

  }

test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)

val originalQuery = {
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
|| (("y.a".attr > 2) && ("y.c".attr < 1)))))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.subquery('x)
val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)
val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
|| (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze

comparePlans(optimized, correctAnswer)
Member Author: This matches PostgreSQL's plan:

postgres=# explain select x.* from x join y on (x.b = y.b and ( (x.a > 3 and x.a < 13 and y.c <= 5) or (y.a > 2 and y.c < 1) ));
                                       QUERY PLAN
-----------------------------------------------------------------------------------------
 Merge Join  (cost=207.30..402.86 rows=1927 width=16)
   Merge Cond: (y.b = x.b)
   Join Filter: (((x.a > 3) AND (x.a < 13) AND (y.c <= 5)) OR ((y.a > 2) AND (y.c < 1)))
   ->  Sort  (cost=78.41..80.30 rows=754 width=12)
         Sort Key: y.b
         ->  Seq Scan on y  (cost=0.00..42.38 rows=754 width=12)
               Filter: ((c <= 5) OR ((a > 2) AND (c < 1)))
   ->  Sort  (cost=128.89..133.52 rows=1850 width=16)
         Sort Key: x.b
         ->  Seq Scan on x  (cost=0.00..28.50 rows=1850 width=16)
(10 rows)

  }

test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_REWRITING_CNF_DEPTH.key}=0") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)

val originalQuery = {
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
|| (("y.a".attr > 2) && ("y.c".attr < 1)))))
}

Seq(0, 10).foreach { depth =>
withSQLConf(SQLConf.MAX_REWRITING_CNF_DEPTH.key -> depth.toString) {
val optimized = Optimize.execute(originalQuery.analyze)
val (left, right) = if (depth == 0) {
(testRelation.subquery('x), testRelation.subquery('y))
} else {
(testRelation.subquery('x),
testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y))
}
val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
|| (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze

comparePlans(optimized, correctAnswer)
}
}
}
}