2020use std:: sync:: Arc ;
2121
2222use datafusion_common:: tree_node:: { Transformed , TreeNode } ;
23- use datafusion_common:: { DFSchema , DFSchemaRef , DataFusionError , Result } ;
23+ use datafusion_common:: { Column , DFSchema , DFSchemaRef , DataFusionError , Result } ;
2424use datafusion_expr:: Expr ;
25- use datafusion_expr:: logical_plan:: LogicalPlan ;
25+ use datafusion_expr:: logical_plan:: { Aggregate , LogicalPlan , Projection } ;
2626use datafusion_expr:: simplify:: SimplifyContext ;
27- use datafusion_expr:: utils:: merge_schema;
27+ use datafusion_expr:: utils:: {
28+ columnize_expr, find_aggregate_exprs, grouping_set_to_exprlist, merge_schema,
29+ } ;
2830
2931use crate :: optimizer:: ApplyOrder ;
3032use crate :: utils:: NamePreserver ;
@@ -130,13 +132,17 @@ impl SimplifyExpressions {
130132 ) )
131133 } ;
132134
133- plan. map_expressions ( |expr| {
135+ let transformed = plan. map_expressions ( |expr| {
134136 // Preserve the aliasing of grouping sets.
135137 if let Expr :: GroupingSet ( _) = & expr {
136138 expr. map_children ( & mut rewrite_expr)
137139 } else {
138140 rewrite_expr ( expr)
139141 }
142+ } ) ?;
143+
144+ transformed. transform_data ( |plan| {
145+ rewrite_aggregate_non_aggregate_aggr_expr ( plan)
140146 } )
141147 }
142148}
@@ -148,6 +154,73 @@ impl SimplifyExpressions {
148154 }
149155}
150156
157+ /// Rewrites
158+ /// `Aggregate(group_expr, aggr_expr=[non_agg_expr(sum(..), count(..), ..)])`
159+ /// into:
160+ /// `Projection(..)` over `Aggregate(group_expr, aggr_expr=[sum(..), count(..), ..])`.
161+ ///
162+ /// Aggregate planning requires each aggregate slot to be an aggregate function
163+ /// (possibly aliased). Some UDAF simplifiers can return larger expressions that
164+ /// reference multiple aggregate functions.
165+ fn rewrite_aggregate_non_aggregate_aggr_expr (
166+ plan : LogicalPlan ,
167+ ) -> Result < Transformed < LogicalPlan > > {
168+ let LogicalPlan :: Aggregate ( Aggregate {
169+ input,
170+ group_expr,
171+ aggr_expr,
172+ ..
173+ } ) = plan
174+ else {
175+ return Ok ( Transformed :: no ( plan) ) ;
176+ } ;
177+
178+ if aggr_expr. iter ( ) . all ( is_top_level_aggregate_expr) {
179+ return Ok ( Transformed :: no ( LogicalPlan :: Aggregate ( Aggregate :: try_new (
180+ input, group_expr, aggr_expr,
181+ ) ?) ) ) ;
182+ }
183+
184+ let inner_aggr_expr = find_aggregate_exprs ( aggr_expr. iter ( ) ) ;
185+ let inner_aggregate = LogicalPlan :: Aggregate ( Aggregate :: try_new (
186+ Arc :: clone ( & input) ,
187+ group_expr. clone ( ) ,
188+ inner_aggr_expr,
189+ ) ?) ;
190+ let inner_aggregate = Arc :: new ( inner_aggregate) ;
191+
192+ let mut projection_exprs = aggregate_output_exprs ( & group_expr) ?;
193+ projection_exprs. extend ( aggr_expr) ;
194+ let projection_exprs = projection_exprs
195+ . into_iter ( )
196+ . map ( |expr| columnize_expr ( expr, inner_aggregate. as_ref ( ) ) )
197+ . collect :: < Result < Vec < _ > > > ( ) ?;
198+
199+ Ok ( Transformed :: yes ( LogicalPlan :: Projection ( Projection :: try_new (
200+ projection_exprs,
201+ inner_aggregate,
202+ ) ?) ) )
203+ }
204+
205+ fn is_top_level_aggregate_expr ( expr : & Expr ) -> bool {
206+ matches ! ( expr. clone( ) . unalias_nested( ) . data, Expr :: AggregateFunction ( _) )
207+ }
208+
209+ fn aggregate_output_exprs ( group_expr : & [ Expr ] ) -> Result < Vec < Expr > > {
210+ let mut output_exprs = grouping_set_to_exprlist ( group_expr) ?
211+ . into_iter ( )
212+ . cloned ( )
213+ . collect :: < Vec < _ > > ( ) ;
214+
215+ if matches ! ( group_expr, [ Expr :: GroupingSet ( _) ] ) {
216+ output_exprs. push ( Expr :: Column ( Column :: from_name (
217+ Aggregate :: INTERNAL_GROUPING_ID ,
218+ ) ) ) ;
219+ }
220+
221+ Ok ( output_exprs)
222+ }
223+
151224#[ cfg( test) ]
152225mod tests {
153226 use std:: ops:: Not ;
@@ -158,7 +231,7 @@ mod tests {
158231 use datafusion_expr:: logical_plan:: builder:: table_scan_with_filters;
159232 use datafusion_expr:: logical_plan:: table_scan;
160233 use datafusion_expr:: * ;
161- use datafusion_functions_aggregate:: expr_fn:: { max, min} ;
234+ use datafusion_functions_aggregate:: expr_fn:: { max, min, sum } ;
162235
163236 use crate :: OptimizerContext ;
164237 use crate :: assert_optimized_plan_eq_snapshot;
@@ -258,6 +331,28 @@ mod tests {
258331 )
259332 }
260333
/// An aggregate slot that simplifies to a non-aggregate expression over
/// several aggregate functions must be planned as a `Projection` on top of an
/// `Aggregate` holding only those aggregate functions.
#[test]
fn test_simplify_udaf_to_non_aggregate_expr() -> Result<()> {
    let fields = vec![Field::new("a", DataType::Int64, false)];
    let schema = Schema::new(fields);
    let scan = table_scan(Some("test"), &schema, None)?
        .build()
        .expect("building scan");

    // No GROUP BY; the single aggregate slot is `sum(a + 2)`.
    let no_group_by = Vec::<Expr>::new();
    let shifted_sum = sum(col("a") + lit(2i64));
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(no_group_by, vec![shifted_sum])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: sum(test.a) + Int64(2) * count(test.a) AS sum(test.a + Int64(2))
          Aggregate: groupBy=[[]], aggr=[[sum(test.a), count(test.a)]]
            TableScan: test
        "
    )?;
    Ok(())
}
355+
261356 #[ test]
262357 fn test_simplify_optimized_plan_with_or ( ) -> Result < ( ) > {
263358 let table_scan = test_table_scan ( ) ;
0 commit comments