1515// specific language governing permissions and limitations
1616// under the License.
1717
18- //! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause
18+ //! [`EliminateGroupByConstant`] removes constant and functionally redundant
19+ //! expressions from `GROUP BY` clause
1920use crate :: optimizer:: ApplyOrder ;
2021use crate :: { OptimizerConfig , OptimizerRule } ;
2122
23+ use std:: collections:: HashSet ;
24+
2225use datafusion_common:: Result ;
2326use datafusion_common:: tree_node:: Transformed ;
2427use datafusion_expr:: { Aggregate , Expr , LogicalPlan , LogicalPlanBuilder , Volatility } ;
@@ -47,25 +50,30 @@ impl OptimizerRule for EliminateGroupByConstant {
4750 ) -> Result < Transformed < LogicalPlan > > {
4851 match plan {
4952 LogicalPlan :: Aggregate ( aggregate) => {
50- let ( const_group_expr, nonconst_group_expr) : ( Vec < _ > , Vec < _ > ) = aggregate
53+ // Collect bare column references in GROUP BY
54+ let group_by_columns: HashSet < & datafusion_common:: Column > = aggregate
5155 . group_expr
5256 . iter ( )
53- . partition ( |expr| is_constant_expression ( expr) ) ;
54-
55- // If no constant expressions found (nothing to optimize) or
56- // constant expression is the only expression in aggregate,
57- // optimization is skipped
58- if const_group_expr. is_empty ( )
59- || ( !const_group_expr. is_empty ( )
60- && nonconst_group_expr. is_empty ( )
61- && aggregate. aggr_expr . is_empty ( ) )
57+ . filter_map ( |expr| match expr {
58+ Expr :: Column ( c) => Some ( c) ,
59+ _ => None ,
60+ } )
61+ . collect ( ) ;
62+
63+ let ( redundant, required) : ( Vec < _ > , Vec < _ > ) = aggregate
64+ . group_expr
65+ . iter ( )
66+ . partition ( |expr| is_redundant_group_expr ( expr, & group_by_columns) ) ;
67+
68+ if redundant. is_empty ( )
69+ || ( required. is_empty ( ) && aggregate. aggr_expr . is_empty ( ) )
6270 {
6371 return Ok ( Transformed :: no ( LogicalPlan :: Aggregate ( aggregate) ) ) ;
6472 }
6573
6674 let simplified_aggregate = LogicalPlan :: Aggregate ( Aggregate :: try_new (
6775 aggregate. input ,
68- nonconst_group_expr . into_iter ( ) . cloned ( ) . collect ( ) ,
76+ required . into_iter ( ) . cloned ( ) . collect ( ) ,
6977 aggregate. aggr_expr . clone ( ) ,
7078 ) ?) ;
7179
@@ -91,23 +99,47 @@ impl OptimizerRule for EliminateGroupByConstant {
9199 }
92100}
93101
94- /// Checks if expression is constant, and can be eliminated from group by.
95- ///
96- /// Intended to be used only within this rule, helper function, which heavily
97- /// relies on `SimplifyExpressions` result.
98- fn is_constant_expression ( expr : & Expr ) -> bool {
102+ /// Checks if a GROUP BY expression is redundant (can be removed without
103+ /// changing grouping semantics). An expression is redundant if it is a
104+ /// deterministic function of constants and columns already present as bare
105+ /// column references in the GROUP BY.
106+ fn is_redundant_group_expr (
107+ expr : & Expr ,
108+ group_by_columns : & HashSet < & datafusion_common:: Column > ,
109+ ) -> bool {
110+ // Bare column references are never redundant - they define the grouping
111+ if matches ! ( expr, Expr :: Column ( _) ) {
112+ return false ;
113+ }
114+ is_deterministic_of ( expr, group_by_columns)
115+ }
116+
117+ /// Returns true if `expr` is a deterministic expression whose only column
118+ /// references are contained in `known_columns`.
119+ fn is_deterministic_of (
120+ expr : & Expr ,
121+ known_columns : & HashSet < & datafusion_common:: Column > ,
122+ ) -> bool {
99123 match expr {
100- Expr :: Alias ( e) => is_constant_expression ( & e. expr ) ,
124+ Expr :: Alias ( e) => is_deterministic_of ( & e. expr , known_columns) ,
125+ Expr :: Column ( c) => known_columns. contains ( c) ,
126+ Expr :: Literal ( _, _) => true ,
101127 Expr :: BinaryExpr ( e) => {
102- is_constant_expression ( & e. left ) && is_constant_expression ( & e. right )
128+ is_deterministic_of ( & e. left , known_columns)
129+ && is_deterministic_of ( & e. right , known_columns)
103130 }
104- Expr :: Literal ( _, _) => true ,
105131 Expr :: ScalarFunction ( e) => {
106132 matches ! (
107133 e. func. signature( ) . volatility,
108134 Volatility :: Immutable | Volatility :: Stable
109- ) && e. args . iter ( ) . all ( is_constant_expression)
135+ ) && e
136+ . args
137+ . iter ( )
138+ . all ( |arg| is_deterministic_of ( arg, known_columns) )
110139 }
140+ Expr :: Cast ( e) => is_deterministic_of ( & e. expr , known_columns) ,
141+ Expr :: TryCast ( e) => is_deterministic_of ( & e. expr , known_columns) ,
142+ Expr :: Negative ( e) => is_deterministic_of ( e, known_columns) ,
111143 _ => false ,
112144 }
113145}
@@ -268,6 +300,43 @@ mod tests {
268300 " )
269301 }
270302
303+ #[ test]
304+ fn test_eliminate_deterministic_expr_of_group_by_column ( ) -> Result < ( ) > {
305+ let scan = test_table_scan ( ) ?;
306+ // GROUP BY a, a - 1, a - 2, a - 3 -> GROUP BY a
307+ let plan = LogicalPlanBuilder :: from ( scan)
308+ . aggregate (
309+ vec ! [
310+ col( "a" ) ,
311+ col( "a" ) - lit( 1u32 ) ,
312+ col( "a" ) - lit( 2u32 ) ,
313+ col( "a" ) - lit( 3u32 ) ,
314+ ] ,
315+ vec ! [ count( col( "c" ) ) ] ,
316+ ) ?
317+ . build ( ) ?;
318+
319+ assert_optimized_plan_equal ! ( plan, @r"
320+ Projection: test.a, test.a - UInt32(1), test.a - UInt32(2), test.a - UInt32(3), count(test.c)
321+ Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
322+ TableScan: test
323+ " )
324+ }
325+
326+ #[ test]
327+ fn test_no_eliminate_independent_columns ( ) -> Result < ( ) > {
328+ // GROUP BY a, b - 1 should NOT eliminate b - 1 (b is not a group by column)
329+ let scan = test_table_scan ( ) ?;
330+ let plan = LogicalPlanBuilder :: from ( scan)
331+ . aggregate ( vec ! [ col( "a" ) , col( "b" ) - lit( 1u32 ) ] , vec ! [ count( col( "c" ) ) ] ) ?
332+ . build ( ) ?;
333+
334+ assert_optimized_plan_equal ! ( plan, @r"
335+ Aggregate: groupBy=[[test.a, test.b - UInt32(1)]], aggr=[[count(test.c)]]
336+ TableScan: test
337+ " )
338+ }
339+
271340 #[ test]
272341 fn test_no_op_volatile_scalar_fn_with_constant_arg ( ) -> Result < ( ) > {
273342 let udf = ScalarUDF :: new_from_impl ( ScalarUDFMock :: new_with_volatility (
0 commit comments