Skip to content

Commit 106e963

Browse files
committed
Refactor join input handling and null evaluation
Consolidate repeated join-input wrapper inspections into a single JoinInputShape classifier. Hoist scalar-subquery cross-join shape check out of the predicate loop and unify repeated left/right predicate bucketing. Remove temporary Vec in join-column replacement inference and narrow test-only null-restriction mode support into its own helper module. Share column-subset check path and extract helper for authoritative null-evaluation results. Reduce repetition in syntactic null-restriction evaluator by factoring strict-null-preserving unary cases.
1 parent 18b7541 commit 106e963

3 files changed

Lines changed: 198 additions & 151 deletions

File tree

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 101 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -285,55 +285,83 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
285285
Ok(is_evaluate)
286286
}
287287

288-
fn strip_plan_wrappers(plan: &LogicalPlan) -> (&LogicalPlan, bool) {
288+
#[derive(Clone, Copy)]
289+
struct JoinInputShape<'a> {
290+
base_plan: &'a LogicalPlan,
291+
is_derived_relation: bool,
292+
}
293+
294+
fn classify_join_input(plan: &LogicalPlan) -> JoinInputShape<'_> {
289295
match plan {
290296
LogicalPlan::SubqueryAlias(subquery_alias) => {
291-
let (plan, _) = strip_plan_wrappers(subquery_alias.input.as_ref());
292-
(plan, true)
297+
let JoinInputShape { base_plan, .. } =
298+
classify_join_input(subquery_alias.input.as_ref());
299+
JoinInputShape {
300+
base_plan,
301+
is_derived_relation: true,
302+
}
293303
}
294304
LogicalPlan::Projection(projection) => {
295-
let (plan, is_derived_relation) =
296-
strip_plan_wrappers(projection.input.as_ref());
297-
(plan, is_derived_relation)
305+
let shape = classify_join_input(projection.input.as_ref());
306+
JoinInputShape {
307+
is_derived_relation: shape.is_derived_relation,
308+
..shape
309+
}
298310
}
299-
_ => (plan, false),
311+
_ => JoinInputShape {
312+
base_plan: plan,
313+
is_derived_relation: false,
314+
},
300315
}
301316
}
302317

303-
fn is_scalar_aggregate_subquery(plan: &LogicalPlan) -> bool {
318+
fn is_scalar_aggregate_subquery(shape: JoinInputShape<'_>) -> bool {
304319
matches!(
305-
strip_plan_wrappers(plan).0,
320+
shape.base_plan,
306321
LogicalPlan::Aggregate(aggregate) if aggregate.group_expr.is_empty()
307322
)
308323
}
309324

310-
fn is_derived_relation(plan: &LogicalPlan) -> bool {
311-
strip_plan_wrappers(plan).1
312-
}
313-
314325
fn is_scalar_subquery_cross_join(join: &Join) -> bool {
326+
let left_shape = classify_join_input(join.left.as_ref());
327+
let right_shape = classify_join_input(join.right.as_ref());
315328
join.on.is_empty()
316329
&& join.filter.is_none()
317-
&& ((is_scalar_aggregate_subquery(join.left.as_ref())
318-
&& is_derived_relation(join.right.as_ref()))
319-
|| (is_scalar_aggregate_subquery(join.right.as_ref())
320-
&& is_derived_relation(join.left.as_ref())))
330+
&& ((is_scalar_aggregate_subquery(left_shape) && right_shape.is_derived_relation)
331+
|| (is_scalar_aggregate_subquery(right_shape)
332+
&& left_shape.is_derived_relation))
321333
}
322334

323335
// Keep post-join filters above certain scalar-subquery cross joins to preserve
324336
// behavior for the window-over-scalar-subquery regression shape.
325337
fn should_keep_filter_above_scalar_subquery_cross_join(
326-
join: &Join,
338+
mut checker: ColumnChecker<'_>,
327339
predicate: &Expr,
328340
) -> bool {
329-
if !is_scalar_subquery_cross_join(join) {
330-
return false;
331-
}
332-
333-
let mut checker = ColumnChecker::new(join.left.schema(), join.right.schema());
334341
!checker.is_left_only(predicate) && !checker.is_right_only(predicate)
335342
}
336343

344+
enum PredicateDestination {
345+
Left,
346+
Right,
347+
Keep,
348+
}
349+
350+
fn classify_predicate_destination(
351+
checker: &mut ColumnChecker<'_>,
352+
predicate: &Expr,
353+
allow_left: bool,
354+
allow_right: bool,
355+
) -> PredicateDestination {
356+
if allow_left && checker.is_left_only(predicate) {
357+
PredicateDestination::Left
358+
} else if allow_right && checker.is_right_only(predicate) {
359+
PredicateDestination::Right
360+
} else {
361+
PredicateDestination::Keep
362+
}
363+
}
364+
337365
/// examine OR clause to see if any useful clauses can be extracted and push down.
338366
/// extract at least one qual from each sub clauses of OR clause, then form the quals
339367
/// to new OR clause as predicate.
@@ -475,29 +503,44 @@ fn push_down_all_join(
475503
let mut keep_predicates = vec![];
476504
let mut join_conditions = vec![];
477505
let mut checker = ColumnChecker::new(left_schema, right_schema);
506+
let keep_mixed_scalar_subquery_filters =
507+
is_inner_join && is_scalar_subquery_cross_join(&join);
478508
for predicate in predicates {
479-
if left_preserved && checker.is_left_only(&predicate) {
480-
left_push.push(predicate);
481-
} else if right_preserved && checker.is_right_only(&predicate) {
482-
right_push.push(predicate);
483-
} else if is_inner_join
484-
&& !should_keep_filter_above_scalar_subquery_cross_join(&join, &predicate)
485-
&& can_evaluate_as_join_condition(&predicate)?
486-
{
487-
// Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
488-
// and convert to the join on condition
489-
join_conditions.push(predicate);
490-
} else {
491-
keep_predicates.push(predicate);
509+
match classify_predicate_destination(
510+
&mut checker,
511+
&predicate,
512+
left_preserved,
513+
right_preserved,
514+
) {
515+
PredicateDestination::Left => left_push.push(predicate),
516+
PredicateDestination::Right => right_push.push(predicate),
517+
PredicateDestination::Keep => {
518+
let should_keep_above_join = keep_mixed_scalar_subquery_filters
519+
&& should_keep_filter_above_scalar_subquery_cross_join(
520+
ColumnChecker::new(left_schema, right_schema),
521+
&predicate,
522+
);
523+
524+
if is_inner_join
525+
&& !should_keep_above_join
526+
&& can_evaluate_as_join_condition(&predicate)?
527+
{
528+
// Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
529+
// and convert to the join on condition
530+
join_conditions.push(predicate);
531+
} else {
532+
keep_predicates.push(predicate);
533+
}
534+
}
492535
}
493536
}
494537

495538
// Push predicates inferred from the join expression
496539
for predicate in inferred_join_predicates {
497-
if checker.is_left_only(&predicate) {
498-
left_push.push(predicate);
499-
} else if checker.is_right_only(&predicate) {
500-
right_push.push(predicate);
540+
match classify_predicate_destination(&mut checker, &predicate, true, true) {
541+
PredicateDestination::Left => left_push.push(predicate),
542+
PredicateDestination::Right => right_push.push(predicate),
543+
PredicateDestination::Keep => {}
501544
}
502545
}
503546

@@ -506,12 +549,15 @@ fn push_down_all_join(
506549

507550
if !on_filter.is_empty() {
508551
for on in on_filter {
509-
if on_left_preserved && checker.is_left_only(&on) {
510-
left_push.push(on)
511-
} else if on_right_preserved && checker.is_right_only(&on) {
512-
right_push.push(on)
513-
} else {
514-
on_filter_join_conditions.push(on)
552+
match classify_predicate_destination(
553+
&mut checker,
554+
&on,
555+
on_left_preserved,
556+
on_right_preserved,
557+
) {
558+
PredicateDestination::Left => left_push.push(on),
559+
PredicateDestination::Right => right_push.push(on),
560+
PredicateDestination::Keep => on_filter_join_conditions.push(on),
515561
}
516562
}
517563
}
@@ -776,35 +822,30 @@ fn infer_join_predicates_impl<
776822
) -> Result<()> {
777823
for predicate in input_predicates {
778824
let column_refs = predicate.column_refs();
779-
let join_col_replacements: Vec<_> = column_refs
825+
let mut saw_non_replaceable_ref = false;
826+
let join_cols_to_replace = column_refs
780827
.iter()
781828
.filter_map(|&col| {
782-
join_col_keys.iter().find_map(|(l, r)| {
829+
let replacement = join_col_keys.iter().find_map(|(l, r)| {
783830
if ENABLE_LEFT_TO_RIGHT && col == *l {
784831
Some((col, *r))
785832
} else if ENABLE_RIGHT_TO_LEFT && col == *r {
786833
Some((col, *l))
787834
} else {
788835
None
789836
}
790-
})
837+
});
838+
saw_non_replaceable_ref |= replacement.is_none();
839+
replacement
791840
})
792-
.collect();
793-
794-
if join_col_replacements.is_empty() {
795-
continue;
796-
}
841+
.collect::<HashMap<_, _>>();
797842

798-
// For non-inner joins, predicates that reference any non-replaceable
799-
// columns cannot be inferred on the other side. Skip the null-restriction
800-
// helper entirely in that common mixed-reference case.
801-
if !inferred_predicates.is_inner_join
802-
&& join_col_replacements.len() != column_refs.len()
843+
if join_cols_to_replace.is_empty()
844+
|| (!inferred_predicates.is_inner_join && saw_non_replaceable_ref)
803845
{
804846
continue;
805847
}
806848

807-
let join_cols_to_replace = join_col_replacements.into_iter().collect();
808849
inferred_predicates
809850
.try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
810851
}

0 commit comments

Comments
 (0)