Skip to content

Commit e74f5da

Browse files
committed
Reduce plan-shape sensitivity for scalar-subquery guard
Broaden derived-relation detection to include projection wrappers over derived relations. Add regression tests to cover alias/projection shape changes and ensure mixed-side filters are preserved. Implement a panic-path robustness test to confirm that eval mode resets properly, even on closure panic using catch_unwind.
1 parent e70fa3a commit e74f5da

2 files changed

Lines changed: 81 additions & 1 deletion

File tree

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,13 @@ fn is_scalar_aggregate_subquery(plan: &LogicalPlan) -> bool {
305305
}
306306

307307
fn is_derived_relation(plan: &LogicalPlan) -> bool {
308-
matches!(plan, LogicalPlan::SubqueryAlias(_))
308+
match plan {
309+
LogicalPlan::SubqueryAlias(_) => true,
310+
LogicalPlan::Projection(projection) => {
311+
is_derived_relation(projection.input.as_ref())
312+
}
313+
_ => false,
314+
}
309315
}
310316

311317
// Keep post-join filters above certain scalar-subquery cross joins to preserve
@@ -2517,6 +2523,60 @@ mod tests {
25172523
)
25182524
}
25192525

2526+
#[test]
2527+
fn window_over_scalar_subquery_cross_join_with_project_wrapper_keeps_filter_above_join(
2528+
) -> Result<()> {
2529+
let left = LogicalPlanBuilder::from(test_table_scan()?)
2530+
.project(vec![col("a").alias("nation"), col("b").alias("acctbal")])?
2531+
.alias("s")?
2532+
.project(vec![col("s.nation"), col("s.acctbal")])?
2533+
.build()?;
2534+
let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
2535+
.project(vec![col("a").alias("acctbal")])?
2536+
.aggregate(
2537+
Vec::<Expr>::new(),
2538+
vec![avg(col("acctbal")).alias("avg_acctbal")],
2539+
)?
2540+
.alias("__scalar_sq_1")?
2541+
.build()?;
2542+
2543+
let window = Expr::from(WindowFunction::new(
2544+
WindowFunctionDefinition::WindowUDF(
2545+
datafusion_functions_window::row_number::row_number_udwf(),
2546+
),
2547+
vec![],
2548+
))
2549+
.partition_by(vec![col("s.nation")])
2550+
.order_by(vec![col("s.acctbal").sort(false, true)])
2551+
.build()
2552+
.unwrap();
2553+
2554+
let plan = LogicalPlanBuilder::from(left)
2555+
.cross_join(right)?
2556+
.filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))?
2557+
.project(vec![col("s.nation"), col("s.acctbal")])?
2558+
.window(vec![window])?
2559+
.build()?;
2560+
2561+
assert_optimized_plan_equal!(
2562+
plan,
2563+
@r"
2564+
WindowAggr: windowExpr=[[row_number() PARTITION BY [s.nation] ORDER BY [s.acctbal DESC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
2565+
Projection: s.nation, s.acctbal
2566+
Filter: s.acctbal > __scalar_sq_1.avg_acctbal
2567+
Cross Join:
2568+
Projection: s.nation, s.acctbal
2569+
SubqueryAlias: s
2570+
Projection: test.a AS nation, test.b AS acctbal
2571+
TableScan: test
2572+
SubqueryAlias: __scalar_sq_1
2573+
Aggregate: groupBy=[[]], aggr=[[avg(acctbal) AS avg_acctbal]]
2574+
Projection: test1.a AS acctbal
2575+
TableScan: test1
2576+
"
2577+
)
2578+
}
2579+
25202580
#[test]
25212581
fn cross_join_builder_uses_inner_join_with_no_join_keys() -> Result<()> {
25222582
let plan = LogicalPlanBuilder::from(test_table_scan()?)

datafusion/optimizer/src/utils.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
245245
#[cfg(test)]
246246
mod tests {
247247
use super::*;
248+
use std::panic::{AssertUnwindSafe, catch_unwind};
249+
248250
use datafusion_expr::{
249251
Operator, binary_expr, case, col, in_list, is_null, lit, when,
250252
};
@@ -512,4 +514,22 @@ mod tests {
512514

513515
Ok(())
514516
}
517+
518+
#[test]
519+
fn null_restriction_eval_mode_guard_restores_on_panic() {
520+
set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto);
521+
522+
let result = catch_unwind(AssertUnwindSafe(|| {
523+
with_null_restriction_eval_mode_for_test(
524+
NullRestrictionEvalMode::AuthoritativeOnly,
525+
|| panic!("intentional panic to verify test mode reset"),
526+
)
527+
}));
528+
529+
assert!(result.is_err());
530+
assert_eq!(
531+
null_restriction_eval_mode(),
532+
NullRestrictionEvalMode::Auto
533+
);
534+
}
515535
}

0 commit comments

Comments
 (0)