@@ -32,12 +32,16 @@ use datafusion_common::types::{
3232 logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
3333} ;
3434use datafusion_common:: { HashMap , Result , ScalarValue , exec_err, not_impl_err} ;
35- use datafusion_expr:: function:: { AccumulatorArgs , StateFieldsArgs } ;
35+ use datafusion_expr:: expr:: { AggregateFunction , AggregateFunctionParams } ;
36+ use datafusion_expr:: function:: {
37+ AccumulatorArgs , AggregateFunctionSimplification , StateFieldsArgs ,
38+ } ;
39+ use datafusion_expr:: simplify:: SimplifyContext ;
3640use datafusion_expr:: utils:: { AggregateOrderSensitivity , format_state_name} ;
3741use datafusion_expr:: {
38- Accumulator , AggregateUDFImpl , Coercion , Documentation , Expr , GroupsAccumulator ,
39- ReversedUDAF , SetMonotonicity , Signature , TypeSignature , TypeSignatureClass ,
40- Volatility ,
42+ Accumulator , AggregateUDFImpl , BinaryExpr , Coercion , Documentation , Expr ,
43+ GroupsAccumulator , Operator , ReversedUDAF , SetMonotonicity , Signature , TypeSignature ,
44+ TypeSignatureClass , Volatility ,
4145} ;
4246use datafusion_functions_aggregate_common:: aggregate:: groups_accumulator:: prim_op:: PrimitiveGroupsAccumulator ;
4347use datafusion_functions_aggregate_common:: aggregate:: sum_distinct:: DistinctSumAccumulator ;
@@ -54,7 +58,7 @@ make_udaf_expr_and_func!(
5458) ;
5559
5660pub fn sum_distinct ( expr : Expr ) -> Expr {
57- Expr :: AggregateFunction ( datafusion_expr :: expr :: AggregateFunction :: new_udf (
61+ Expr :: AggregateFunction ( AggregateFunction :: new_udf (
5862 sum_udaf ( ) ,
5963 vec ! [ expr] ,
6064 true ,
@@ -346,6 +350,88 @@ impl AggregateUDFImpl for Sum {
346350 _ => SetMonotonicity :: NotMonotonic ,
347351 }
348352 }
353+
354+ /// Simplification Rules
355+ fn simplify ( & self ) -> Option < AggregateFunctionSimplification > {
356+ Some ( Box :: new ( sum_simplifier) )
357+ }
358+ }
359+
360+ /// Implement ClickBench Q29 specific optimization:
361+ /// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
362+ ///
363+ /// Backstory: TODO
364+ ///
365+ fn sum_simplifier ( mut agg : AggregateFunction , _info : & SimplifyContext ) -> Result < Expr > {
366+ // Explicitly destructure to ensure we check all relevant fields
367+ let AggregateFunctionParams {
368+ args,
369+ distinct,
370+ filter,
371+ order_by,
372+ null_treatment,
373+ } = & agg. params ;
374+
375+ if * distinct
376+ || filter. is_some ( )
377+ || !order_by. is_empty ( )
378+ || null_treatment. is_some ( )
379+ || args. len ( ) != 1
380+ {
381+ return Ok ( Expr :: AggregateFunction ( agg) ) ;
382+ }
383+
384+ // otherwise check the arguments if they are <col> <op> scalar
385+ let ( arg, lit) = match SplitResult :: new ( agg. params . args . swap_remove ( 0 ) ) {
386+ SplitResult :: Original ( expr) => {
387+ agg. params . args . push ( expr) ; // put it back
388+ return Ok ( Expr :: AggregateFunction ( agg) ) ;
389+ }
390+ SplitResult :: Split { arg, lit } => ( arg, lit) ,
391+ } ;
392+
393+ // Rewrite to SUM(arg)
394+ agg. params . args . push ( arg. clone ( ) ) ;
395+ let sum_agg = Expr :: AggregateFunction ( agg) ;
396+
397+ // sum(arg) + scalar * COUNT(arg)
398+ Ok ( sum_agg + ( lit * crate :: count:: count ( arg) ) )
399+ }
400+
401+ /// Result of trying to split an expression into an arg and constant
402+ #[ derive( Debug , Clone ) ]
403+ enum SplitResult {
404+ /// if the expression is either of
405+ /// * `<arg> <op> <lit>`
406+ /// * `<lit> <op> <arg>`
407+ ///
408+ /// When `op` is `+`
409+ Split { arg : Expr , lit : Expr } ,
410+ /// If the expression is something else
411+ Original ( Expr ) ,
412+ }
413+
414+ impl SplitResult {
415+ fn new ( expr : Expr ) -> Self {
416+ let Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) = expr else {
417+ return Self :: Original ( expr) ;
418+ } ;
419+ if op != Operator :: Plus {
420+ return Self :: Original ( Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) ) ;
421+ }
422+
423+ match ( left. as_ref ( ) , right. as_ref ( ) ) {
424+ ( Expr :: Literal ( ..) , _) => Self :: Split {
425+ arg : * right,
426+ lit : * left,
427+ } ,
428+ ( _, Expr :: Literal ( ..) ) => Self :: Split {
429+ arg : * left,
430+ lit : * right,
431+ } ,
432+ _ => Self :: Original ( Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) ) ,
433+ }
434+ }
349435}
350436
351437/// This accumulator computes SUM incrementally
0 commit comments