@@ -35,7 +35,7 @@ use datafusion_common::{
3535} ;
3636use datafusion_expr:: expr:: WindowFunction ;
3737use datafusion_expr:: expr_rewriter:: replace_col;
38- use datafusion_expr:: logical_plan:: { Aggregate , Join , JoinType , LogicalPlan , TableScan , Union } ;
38+ use datafusion_expr:: logical_plan:: { Join , JoinType , LogicalPlan , TableScan , Union } ;
3939use datafusion_expr:: utils:: {
4040 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
4141} ;
@@ -431,10 +431,7 @@ fn push_down_all_join(
431431 left_push. push ( predicate) ;
432432 } else if right_preserved && checker. is_right_only ( & predicate) {
433433 right_push. push ( predicate) ;
434- } else if is_inner_join
435- && can_promote_post_join_filter_to_join_condition ( & join)
436- && can_evaluate_as_join_condition ( & predicate) ?
437- {
434+ } else if is_inner_join && can_evaluate_as_join_condition ( & predicate) ? {
438435 // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
439436 // and convert to the join on condition
440437 join_conditions. push ( predicate) ;
@@ -515,70 +512,6 @@ fn push_down_all_join(
515512 Ok ( Transformed :: yes ( plan) )
516513}
517514
518- /// Returns true when post-join filters are allowed to be promoted to join conditions.
519- ///
520- /// Protection is necessary for scalar-side joins and cross joins to avoid incorrectly
521- /// rewriting a post-join filter into the join condition when one side may disappear
522- /// entirely, even though `max_rows() == Some(1)`.
523- ///
524- /// - `join.on` non-empty means existing join predicates already exist; promotion is safe.
525- /// - if a scalar side is a scalar-subquery-shaped input that is provably exactly one
526- /// row, promotion is safe.
527- /// - otherwise, keep the filter above the join.
528- fn can_promote_post_join_filter_to_join_condition ( join : & Join ) -> bool {
529- !join. on . is_empty ( )
530- || scalar_side_can_promote_post_join_filter ( join. left . as_ref ( ) )
531- && scalar_side_can_promote_post_join_filter ( join. right . as_ref ( ) )
532- }
533-
534- /// Returns true when a non-scalar side is unrestricted, or when a scalar side is
535- /// a safe exact-one-row scalar-subquery shape.
536- fn scalar_side_can_promote_post_join_filter ( plan : & LogicalPlan ) -> bool {
537- !is_scalar_side ( plan) || is_safe_scalar_subquery_side ( plan)
538- }
539-
540- fn is_scalar_side ( plan : & LogicalPlan ) -> bool {
541- matches ! ( plan. max_rows( ) , Some ( 1 ) )
542- }
543-
544- /// Returns true for the scalar-subquery-shaped inputs where post-join filter
545- /// promotion should remain legal.
546- fn is_safe_scalar_subquery_side ( plan : & LogicalPlan ) -> bool {
547- match plan {
548- LogicalPlan :: Projection ( projection) => {
549- is_safe_scalar_subquery_side ( projection. input . as_ref ( ) )
550- }
551- LogicalPlan :: Repartition ( repartition) => {
552- is_safe_scalar_subquery_side ( repartition. input . as_ref ( ) )
553- }
554- LogicalPlan :: Sort ( sort) => is_safe_scalar_subquery_side ( sort. input . as_ref ( ) ) ,
555- LogicalPlan :: SubqueryAlias ( subquery_alias) => {
556- returns_exactly_one_row ( subquery_alias. input . as_ref ( ) )
557- }
558- _ => false ,
559- }
560- }
561-
562- /// Returns true when the plan is guaranteed to produce exactly one row.
563- fn returns_exactly_one_row ( plan : & LogicalPlan ) -> bool {
564- match plan {
565- LogicalPlan :: Projection ( projection) => returns_exactly_one_row ( projection. input . as_ref ( ) ) ,
566- LogicalPlan :: SubqueryAlias ( subquery_alias) => {
567- returns_exactly_one_row ( subquery_alias. input . as_ref ( ) )
568- }
569- LogicalPlan :: Repartition ( repartition) => {
570- returns_exactly_one_row ( repartition. input . as_ref ( ) )
571- }
572- LogicalPlan :: Sort ( sort) => returns_exactly_one_row ( sort. input . as_ref ( ) ) ,
573- LogicalPlan :: Aggregate ( Aggregate { group_expr, .. } ) => {
574- group_expr
575- . iter ( )
576- . all ( |expr| matches ! ( expr, Expr :: Literal ( _, _) ) )
577- }
578- _ => false ,
579- }
580- }
581-
582515fn push_down_join (
583516 join : Join ,
584517 parent_predicate : Option < & Expr > ,
@@ -1562,7 +1495,7 @@ mod tests {
15621495 use crate :: simplify_expressions:: SimplifyExpressions ;
15631496 use crate :: test:: udfs:: leaf_udf_expr;
15641497 use crate :: test:: * ;
1565- use datafusion_expr:: test:: function_stub:: { avg , sum} ;
1498+ use datafusion_expr:: test:: function_stub:: sum;
15661499 use insta:: assert_snapshot;
15671500
15681501 use super :: * ;
@@ -3699,124 +3632,6 @@ mod tests {
36993632 )
37003633 }
37013634
3702- #[ test]
3703- fn cross_join_with_scalar_side_keeps_post_join_filter ( ) -> Result < ( ) > {
3704- let left = LogicalPlanBuilder :: from ( test_table_scan ( ) ?)
3705- . project ( vec ! [ col( "a" ) , col( "b" ) ] ) ?
3706- . build ( ) ?;
3707- let right = LogicalPlanBuilder :: from ( test_table_scan_with_name ( "test1" ) ?)
3708- . project ( vec ! [ col( "a" ) ] ) ?
3709- . aggregate ( Vec :: < Expr > :: new ( ) , vec ! [ avg( col( "a" ) ) . alias( "avg_a" ) ] ) ?
3710- . build ( ) ?;
3711- let plan = LogicalPlanBuilder :: from ( left)
3712- . cross_join ( right) ?
3713- . filter ( col ( "test.b" ) . gt ( col ( "avg_a" ) ) ) ?
3714- . build ( ) ?;
3715-
3716- assert_optimized_plan_equal ! (
3717- plan,
3718- @r"
3719- Filter: test.b > avg_a
3720- Cross Join:
3721- Projection: test.a, test.b
3722- TableScan: test
3723- Aggregate: groupBy=[[]], aggr=[[avg(test1.a) AS avg_a]]
3724- Projection: test1.a
3725- TableScan: test1
3726- "
3727- )
3728- }
3729-
3730- #[ test]
3731- fn cross_join_with_exact_one_row_subquery_promotes_post_join_filter ( ) -> Result < ( ) > {
3732- let left = LogicalPlanBuilder :: from ( test_table_scan ( ) ?)
3733- . project ( vec ! [ col( "a" ) , col( "b" ) ] ) ?
3734- . build ( ) ?;
3735- let right = LogicalPlanBuilder :: from ( test_table_scan_with_name ( "test1" ) ?)
3736- . project ( vec ! [ col( "a" ) ] ) ?
3737- . aggregate ( Vec :: < Expr > :: new ( ) , vec ! [ avg( col( "a" ) ) . alias( "avg_a" ) ] ) ?
3738- . alias ( "sq" ) ?
3739- . build ( ) ?;
3740- let plan = LogicalPlanBuilder :: from ( left)
3741- . cross_join ( right) ?
3742- . filter ( col ( "test.b" ) . gt ( col ( "sq.avg_a" ) ) ) ?
3743- . build ( ) ?;
3744-
3745- assert_optimized_plan_equal ! (
3746- plan,
3747- @r"
3748- Inner Join: Filter: test.b > sq.avg_a
3749- Projection: test.a, test.b
3750- TableScan: test
3751- SubqueryAlias: sq
3752- Aggregate: groupBy=[[]], aggr=[[avg(test1.a) AS avg_a]]
3753- Projection: test1.a
3754- TableScan: test1
3755- "
3756- )
3757- }
3758-
3759- #[ test]
3760- fn cross_join_with_at_most_one_row_side_keeps_post_join_filter ( ) -> Result < ( ) > {
3761- let left = LogicalPlanBuilder :: from ( test_table_scan ( ) ?)
3762- . project ( vec ! [ col( "a" ) , col( "b" ) ] ) ?
3763- . build ( ) ?;
3764- let right = LogicalPlanBuilder :: from ( test_table_scan_with_name ( "test1" ) ?)
3765- . project ( vec ! [ col( "a" ) ] ) ?
3766- . limit ( 0 , Some ( 1 ) ) ?
3767- . alias ( "sq" ) ?
3768- . build ( ) ?;
3769- let plan = LogicalPlanBuilder :: from ( left)
3770- . cross_join ( right) ?
3771- . filter ( col ( "test.b" ) . gt ( col ( "sq.a" ) ) ) ?
3772- . build ( ) ?;
3773-
3774- assert_optimized_plan_equal ! (
3775- plan,
3776- @r"
3777- Filter: test.b > sq.a
3778- Cross Join:
3779- Projection: test.a, test.b
3780- TableScan: test
3781- SubqueryAlias: sq
3782- Limit: skip=0, fetch=1
3783- Projection: test1.a
3784- TableScan: test1
3785- "
3786- )
3787- }
3788-
3789- #[ test]
3790- fn returns_exactly_one_row_for_global_aggregate ( ) -> Result < ( ) > {
3791- let plan = LogicalPlanBuilder :: from ( test_table_scan ( ) ?)
3792- . aggregate ( Vec :: < Expr > :: new ( ) , vec ! [ avg( col( "a" ) ) ] ) ?
3793- . build ( ) ?;
3794-
3795- assert ! ( returns_exactly_one_row( & plan) ) ;
3796- Ok ( ( ) )
3797- }
3798-
3799- #[ test]
3800- fn returns_exactly_one_row_is_false_for_filtered_global_aggregate ( ) -> Result < ( ) > {
3801- let plan = LogicalPlanBuilder :: from ( test_table_scan ( ) ?)
3802- . aggregate ( Vec :: < Expr > :: new ( ) , vec ! [ avg( col( "a" ) ) ] ) ?
3803- . filter ( col ( "avg(test.a)" ) . gt ( lit ( 0i64 ) ) ) ?
3804- . build ( ) ?;
3805-
3806- assert ! ( !returns_exactly_one_row( & plan) ) ;
3807- Ok ( ( ) )
3808- }
3809-
3810- #[ test]
3811- fn returns_exactly_one_row_is_false_for_limit_fetch_one ( ) -> Result < ( ) > {
3812- let plan = LogicalPlanBuilder :: from ( test_table_scan ( ) ?)
3813- . limit ( 0 , Some ( 1 ) ) ?
3814- . build ( ) ?;
3815-
3816- assert ! ( !returns_exactly_one_row( & plan) ) ;
3817- Ok ( ( ) )
3818- }
3819-
38203635 #[ test]
38213636 fn left_semi_join ( ) -> Result < ( ) > {
38223637 let left = test_table_scan_with_name ( "test1" ) ?;
0 commit comments