Skip to content

Commit d1800db

Browse files
authored
fix: optimize_projections failure after mark joins created by EXISTS OR EXISTS (#21265)
## Which issue does this PR close? - Closes #20083. ## Rationale for this change Issue has details but main problem is mark columns from LeftMark joins leak into parent join schemas, causing `optimize_projections` optimizer to fail. ## What changes are included in this PR? Add a projection after embedded subquery decorrelation to strip mark columns, following the same pattern as `scalar_subquery_to_join`. I've seen this projection is merged in the final plan. ## Are these changes tested? Added test case for reported failure ## Are there any user-facing changes? No.
1 parent 4b8c1d9 commit d1800db

2 files changed

Lines changed: 95 additions & 14 deletions

File tree

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
8989

9090
// iterate through all exists clauses in predicate, turning each into a join
9191
let mut cur_input = Arc::unwrap_or_clone(filter.input);
92+
let original_schema = cur_input.schema().columns();
9293
for subquery_expr in with_subqueries {
9394
match extract_subquery_info(subquery_expr) {
9495
// The subquery expression is at the top level of the filter
@@ -115,6 +116,13 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
115116
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
116117
cur_input = LogicalPlan::Filter(new_filter);
117118
}
119+
120+
if cur_input.schema().fields().len() != original_schema.len() {
121+
cur_input = LogicalPlanBuilder::from(cur_input)
122+
.project(original_schema.into_iter().map(Expr::from))?
123+
.build()?;
124+
}
125+
118126
Ok(Transformed::yes(cur_input))
119127
}
120128

@@ -1736,13 +1744,14 @@ mod tests {
17361744
plan,
17371745
@r"
17381746
Projection: customer.c_custkey [c_custkey:Int64]
1739-
Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1740-
LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1741-
TableScan: customer [c_custkey:Int64, c_name:Utf8]
1742-
SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1743-
Projection: orders.o_custkey [o_custkey:Int64]
1744-
Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1745-
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1747+
Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1748+
Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1749+
LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1750+
TableScan: customer [c_custkey:Int64, c_name:Utf8]
1751+
SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1752+
Projection: orders.o_custkey [o_custkey:Int64]
1753+
Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1754+
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
17461755
"
17471756
)
17481757
}

datafusion/optimizer/src/optimize_projections/mod.rs

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,9 @@ fn optimize_projections(
390390
}
391391
LogicalPlan::Join(join) => {
392392
let left_len = join.left.schema().fields().len();
393+
let right_len = join.right.schema().fields().len();
393394
let (left_req_indices, right_req_indices) =
394-
split_join_requirements(left_len, indices, &join.join_type);
395+
split_join_requirements(left_len, right_len, indices, &join.join_type);
395396
let left_indices =
396397
left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
397398
let right_indices =
@@ -751,6 +752,7 @@ fn outer_columns_helper_multi<'a, 'b>(
751752
/// # Parameters
752753
///
753754
/// * `left_len` - The length of the left child.
755+
/// * `right_len` - The length of the right child.
754756
/// * `indices` - A slice of requirement indices.
755757
/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`).
756758
///
@@ -762,21 +764,29 @@ fn outer_columns_helper_multi<'a, 'b>(
762764
/// adjusted based on the join type.
763765
fn split_join_requirements(
764766
left_len: usize,
767+
right_len: usize,
765768
indices: RequiredIndices,
766769
join_type: &JoinType,
767770
) -> (RequiredIndices, RequiredIndices) {
768771
match join_type {
769772
// In these cases requirements are split between left/right children:
770-
JoinType::Inner
771-
| JoinType::Left
772-
| JoinType::Right
773-
| JoinType::Full
774-
| JoinType::LeftMark
775-
| JoinType::RightMark => {
773+
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
776774
// Decrease right side indices by `left_len` so that they point to valid
777775
// positions within the right child:
778776
indices.split_off(left_len)
779777
}
778+
JoinType::LeftMark => {
779+
// LeftMark output: [left_cols(0..left_len), mark]
780+
// The mark column is synthetic (produced by the join itself),
781+
// so discard it and route only to the left child.
782+
let (left_indices, _mark) = indices.split_off(left_len);
783+
(left_indices, RequiredIndices::new())
784+
}
785+
JoinType::RightMark => {
786+
// Same as LeftMark, but for the right child.
787+
let (right_indices, _mark) = indices.split_off(right_len);
788+
(RequiredIndices::new(), right_indices)
789+
}
780790
// All requirements can be re-routed to left child directly.
781791
JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
782792
// All requirements can be re-routed to right side directly.
@@ -2390,6 +2400,68 @@ mod tests {
23902400
)
23912401
}
23922402

2403+
// Regression test for https://github.com/apache/datafusion/issues/20083
2404+
// Optimizer must not fail when LeftMark joins from EXISTS OR EXISTS
2405+
// feed into a Left join.
2406+
#[test]
2407+
fn optimize_projections_exists_or_exists_with_outer_join() -> Result<()> {
2408+
use datafusion_expr::utils::disjunction;
2409+
use datafusion_expr::{exists, out_ref_col};
2410+
2411+
let table_a = test_table_scan_with_name("a")?;
2412+
let table_b = test_table_scan_with_name("b")?;
2413+
2414+
let sq_a = Arc::new(
2415+
LogicalPlanBuilder::from(test_table_scan_with_name("sq_a")?)
2416+
.filter(col("sq_a.a").eq(out_ref_col(DataType::UInt32, "a.a")))?
2417+
.project(vec![lit(1)])?
2418+
.build()?,
2419+
);
2420+
2421+
let sq_b = Arc::new(
2422+
LogicalPlanBuilder::from(test_table_scan_with_name("sq_b")?)
2423+
.filter(col("sq_b.b").eq(out_ref_col(DataType::UInt32, "a.b")))?
2424+
.project(vec![lit(1)])?
2425+
.build()?,
2426+
);
2427+
2428+
let plan = LogicalPlanBuilder::from(table_a)
2429+
.filter(disjunction(vec![exists(sq_a), exists(sq_b)]).unwrap())?
2430+
.join(table_b, JoinType::Left, (vec!["a"], vec!["a"]), None)?
2431+
.build()?;
2432+
2433+
let optimizer = Optimizer::new();
2434+
let config = OptimizerContext::new();
2435+
optimizer.optimize(plan, &config, observe)?;
2436+
2437+
Ok(())
2438+
}
2439+
2440+
#[test]
2441+
fn optimize_projections_left_mark_join_with_projection() -> Result<()> {
2442+
let table_a = test_table_scan_with_name("a")?;
2443+
let table_b = test_table_scan_with_name("b")?;
2444+
let table_c = test_table_scan_with_name("c")?;
2445+
2446+
let plan = LogicalPlanBuilder::from(table_a)
2447+
.join(table_b, JoinType::LeftMark, (vec!["a"], vec!["a"]), None)?
2448+
.project(vec![col("a.a"), col("a.b"), col("a.c")])?
2449+
.join(table_c, JoinType::Left, (vec!["a"], vec!["a"]), None)?
2450+
.build()?;
2451+
2452+
assert_optimized_plan_equal!(
2453+
plan,
2454+
@r"
2455+
Left Join: a.a = c.a
2456+
Projection: a.a, a.b, a.c
2457+
LeftMark Join: a.a = b.a
2458+
TableScan: a projection=[a, b, c]
2459+
TableScan: b projection=[a]
2460+
TableScan: c projection=[a, b, c]
2461+
"
2462+
)
2463+
}
2464+
23932465
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
23942466

23952467
fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {

0 commit comments

Comments
 (0)