@@ -285,55 +285,83 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
285285 Ok ( is_evaluate)
286286}
287287
288- fn strip_plan_wrappers ( plan : & LogicalPlan ) -> ( & LogicalPlan , bool ) {
288+ #[ derive( Clone , Copy ) ]
289+ struct JoinInputShape < ' a > {
290+ base_plan : & ' a LogicalPlan ,
291+ is_derived_relation : bool ,
292+ }
293+
294+ fn classify_join_input ( plan : & LogicalPlan ) -> JoinInputShape < ' _ > {
289295 match plan {
290296 LogicalPlan :: SubqueryAlias ( subquery_alias) => {
291- let ( plan, _) = strip_plan_wrappers ( subquery_alias. input . as_ref ( ) ) ;
292- ( plan, true )
297+ let JoinInputShape { base_plan, .. } =
298+ classify_join_input ( subquery_alias. input . as_ref ( ) ) ;
299+ JoinInputShape {
300+ base_plan,
301+ is_derived_relation : true ,
302+ }
293303 }
294304 LogicalPlan :: Projection ( projection) => {
295- let ( plan, is_derived_relation) =
296- strip_plan_wrappers ( projection. input . as_ref ( ) ) ;
297- ( plan, is_derived_relation)
305+ let shape = classify_join_input ( projection. input . as_ref ( ) ) ;
306+ JoinInputShape {
307+ is_derived_relation : shape. is_derived_relation ,
308+ ..shape
309+ }
298310 }
299- _ => ( plan, false ) ,
311+ _ => JoinInputShape {
312+ base_plan : plan,
313+ is_derived_relation : false ,
314+ } ,
300315 }
301316}
302317
303- fn is_scalar_aggregate_subquery ( plan : & LogicalPlan ) -> bool {
318+ fn is_scalar_aggregate_subquery ( shape : JoinInputShape < ' _ > ) -> bool {
304319 matches ! (
305- strip_plan_wrappers ( plan ) . 0 ,
320+ shape . base_plan ,
306321 LogicalPlan :: Aggregate ( aggregate) if aggregate. group_expr. is_empty( )
307322 )
308323}
309324
310- fn is_derived_relation ( plan : & LogicalPlan ) -> bool {
311- strip_plan_wrappers ( plan) . 1
312- }
313-
314325fn is_scalar_subquery_cross_join ( join : & Join ) -> bool {
326+ let left_shape = classify_join_input ( join. left . as_ref ( ) ) ;
327+ let right_shape = classify_join_input ( join. right . as_ref ( ) ) ;
315328 join. on . is_empty ( )
316329 && join. filter . is_none ( )
317- && ( ( is_scalar_aggregate_subquery ( join. left . as_ref ( ) )
318- && is_derived_relation ( join. right . as_ref ( ) ) )
319- || ( is_scalar_aggregate_subquery ( join. right . as_ref ( ) )
320- && is_derived_relation ( join. left . as_ref ( ) ) ) )
330+ && ( ( is_scalar_aggregate_subquery ( left_shape) && right_shape. is_derived_relation )
331+ || ( is_scalar_aggregate_subquery ( right_shape)
332+ && left_shape. is_derived_relation ) )
321333}
322334
323335// Keep post-join filters above certain scalar-subquery cross joins to preserve
324336// behavior for the window-over-scalar-subquery regression shape.
325337fn should_keep_filter_above_scalar_subquery_cross_join (
326- join : & Join ,
338+ mut checker : ColumnChecker < ' _ > ,
327339 predicate : & Expr ,
328340) -> bool {
329- if !is_scalar_subquery_cross_join ( join) {
330- return false ;
331- }
332-
333- let mut checker = ColumnChecker :: new ( join. left . schema ( ) , join. right . schema ( ) ) ;
334341 !checker. is_left_only ( predicate) && !checker. is_right_only ( predicate)
335342}
336343
/// Where a predicate should be placed relative to a join.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PredicateDestination {
    /// Route the predicate to the left join input.
    Left,
    /// Route the predicate to the right join input.
    Right,
    /// Do not route the predicate to either input; the caller decides what
    /// to do with it.
    Keep,
}
349+
350+ fn classify_predicate_destination (
351+ checker : & mut ColumnChecker < ' _ > ,
352+ predicate : & Expr ,
353+ allow_left : bool ,
354+ allow_right : bool ,
355+ ) -> PredicateDestination {
356+ if allow_left && checker. is_left_only ( predicate) {
357+ PredicateDestination :: Left
358+ } else if allow_right && checker. is_right_only ( predicate) {
359+ PredicateDestination :: Right
360+ } else {
361+ PredicateDestination :: Keep
362+ }
363+ }
364+
337365/// Examine an OR clause to see if any useful clauses can be extracted and pushed down.
338366/// Extract at least one qual from each sub-clause of the OR clause, then combine those
339367/// quals into a new OR clause to use as the predicate.
@@ -475,29 +503,44 @@ fn push_down_all_join(
475503 let mut keep_predicates = vec ! [ ] ;
476504 let mut join_conditions = vec ! [ ] ;
477505 let mut checker = ColumnChecker :: new ( left_schema, right_schema) ;
506+ let keep_mixed_scalar_subquery_filters =
507+ is_inner_join && is_scalar_subquery_cross_join ( & join) ;
478508 for predicate in predicates {
479- if left_preserved && checker. is_left_only ( & predicate) {
480- left_push. push ( predicate) ;
481- } else if right_preserved && checker. is_right_only ( & predicate) {
482- right_push. push ( predicate) ;
483- } else if is_inner_join
484- && !should_keep_filter_above_scalar_subquery_cross_join ( & join, & predicate)
485- && can_evaluate_as_join_condition ( & predicate) ?
486- {
487- // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
488- // and convert to the join on condition
489- join_conditions. push ( predicate) ;
490- } else {
491- keep_predicates. push ( predicate) ;
509+ match classify_predicate_destination (
510+ & mut checker,
511+ & predicate,
512+ left_preserved,
513+ right_preserved,
514+ ) {
515+ PredicateDestination :: Left => left_push. push ( predicate) ,
516+ PredicateDestination :: Right => right_push. push ( predicate) ,
517+ PredicateDestination :: Keep => {
518+ let should_keep_above_join = keep_mixed_scalar_subquery_filters
519+ && should_keep_filter_above_scalar_subquery_cross_join (
520+ ColumnChecker :: new ( left_schema, right_schema) ,
521+ & predicate,
522+ ) ;
523+
524+ if is_inner_join
525+ && !should_keep_above_join
526+ && can_evaluate_as_join_condition ( & predicate) ?
527+ {
528+ // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
529+ // and convert to the join on condition
530+ join_conditions. push ( predicate) ;
531+ } else {
532+ keep_predicates. push ( predicate) ;
533+ }
534+ }
492535 }
493536 }
494537
495538 // Push predicates inferred from the join expression
496539 for predicate in inferred_join_predicates {
497- if checker . is_left_only ( & predicate) {
498- left_push. push ( predicate) ;
499- } else if checker . is_right_only ( & predicate) {
500- right_push . push ( predicate ) ;
540+ match classify_predicate_destination ( & mut checker , & predicate, true , true ) {
541+ PredicateDestination :: Left => left_push. push ( predicate) ,
542+ PredicateDestination :: Right => right_push . push ( predicate) ,
543+ PredicateDestination :: Keep => { }
501544 }
502545 }
503546
@@ -506,12 +549,15 @@ fn push_down_all_join(
506549
507550 if !on_filter. is_empty ( ) {
508551 for on in on_filter {
509- if on_left_preserved && checker. is_left_only ( & on) {
510- left_push. push ( on)
511- } else if on_right_preserved && checker. is_right_only ( & on) {
512- right_push. push ( on)
513- } else {
514- on_filter_join_conditions. push ( on)
552+ match classify_predicate_destination (
553+ & mut checker,
554+ & on,
555+ on_left_preserved,
556+ on_right_preserved,
557+ ) {
558+ PredicateDestination :: Left => left_push. push ( on) ,
559+ PredicateDestination :: Right => right_push. push ( on) ,
560+ PredicateDestination :: Keep => on_filter_join_conditions. push ( on) ,
515561 }
516562 }
517563 }
@@ -776,35 +822,30 @@ fn infer_join_predicates_impl<
776822) -> Result < ( ) > {
777823 for predicate in input_predicates {
778824 let column_refs = predicate. column_refs ( ) ;
779- let join_col_replacements: Vec < _ > = column_refs
825+ let mut saw_non_replaceable_ref = false ;
826+ let join_cols_to_replace = column_refs
780827 . iter ( )
781828 . filter_map ( |& col| {
782- join_col_keys. iter ( ) . find_map ( |( l, r) | {
829+ let replacement = join_col_keys. iter ( ) . find_map ( |( l, r) | {
783830 if ENABLE_LEFT_TO_RIGHT && col == * l {
784831 Some ( ( col, * r) )
785832 } else if ENABLE_RIGHT_TO_LEFT && col == * r {
786833 Some ( ( col, * l) )
787834 } else {
788835 None
789836 }
790- } )
837+ } ) ;
838+ saw_non_replaceable_ref |= replacement. is_none ( ) ;
839+ replacement
791840 } )
792- . collect ( ) ;
793-
794- if join_col_replacements. is_empty ( ) {
795- continue ;
796- }
841+ . collect :: < HashMap < _ , _ > > ( ) ;
797842
798- // For non-inner joins, predicates that reference any non-replaceable
799- // columns cannot be inferred on the other side. Skip the null-restriction
800- // helper entirely in that common mixed-reference case.
801- if !inferred_predicates. is_inner_join
802- && join_col_replacements. len ( ) != column_refs. len ( )
843+ if join_cols_to_replace. is_empty ( )
844+ || ( !inferred_predicates. is_inner_join && saw_non_replaceable_ref)
803845 {
804846 continue ;
805847 }
806848
807- let join_cols_to_replace = join_col_replacements. into_iter ( ) . collect ( ) ;
808849 inferred_predicates
809850 . try_build_predicate ( predicate. clone ( ) , & join_cols_to_replace) ?;
810851 }
0 commit comments