@@ -32,12 +32,17 @@ use datafusion_common::types::{
3232 logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
3333} ;
3434use datafusion_common:: { HashMap , Result , ScalarValue , exec_err, not_impl_err} ;
35- use datafusion_expr:: function:: { AccumulatorArgs , StateFieldsArgs } ;
35+ use datafusion_expr:: expr:: { AggregateFunction , AggregateFunctionParams } ;
36+ use datafusion_expr:: expr_fn:: cast;
37+ use datafusion_expr:: function:: {
38+ AccumulatorArgs , AggregateFunctionSimplification , StateFieldsArgs ,
39+ } ;
40+ use datafusion_expr:: simplify:: SimplifyContext ;
3641use datafusion_expr:: utils:: { AggregateOrderSensitivity , format_state_name} ;
3742use datafusion_expr:: {
38- Accumulator , AggregateUDFImpl , Coercion , Documentation , Expr , GroupsAccumulator ,
39- ReversedUDAF , SetMonotonicity , Signature , TypeSignature , TypeSignatureClass ,
40- Volatility ,
43+ Accumulator , AggregateUDFImpl , BinaryExpr , Coercion , Documentation , Expr ,
44+ GroupsAccumulator , Operator , ReversedUDAF , SetMonotonicity , Signature , TypeSignature ,
45+ TypeSignatureClass , Volatility ,
4146} ;
4247use datafusion_functions_aggregate_common:: aggregate:: groups_accumulator:: prim_op:: PrimitiveGroupsAccumulator ;
4348use datafusion_functions_aggregate_common:: aggregate:: sum_distinct:: DistinctSumAccumulator ;
@@ -54,7 +59,7 @@ make_udaf_expr_and_func!(
5459) ;
5560
5661pub fn sum_distinct ( expr : Expr ) -> Expr {
57- Expr :: AggregateFunction ( datafusion_expr :: expr :: AggregateFunction :: new_udf (
62+ Expr :: AggregateFunction ( AggregateFunction :: new_udf (
5863 sum_udaf ( ) ,
5964 vec ! [ expr] ,
6065 true ,
@@ -346,6 +351,147 @@ impl AggregateUDFImpl for Sum {
346351 _ => SetMonotonicity :: NotMonotonic ,
347352 }
348353 }
354+
355+ /// Simplification Rules
356+ fn simplify ( & self ) -> Option < AggregateFunctionSimplification > {
357+ Some ( Box :: new ( sum_simplifier) )
358+ }
359+ }
360+
361+ /// Implement ClickBench Q29 specific optimization:
362+ /// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
363+ ///
364+ /// Backstory: TODO
365+ ///
366+ fn sum_simplifier ( mut agg : AggregateFunction , info : & SimplifyContext ) -> Result < Expr > {
367+ // Explicitly destructure to ensure we check all relevant fields
368+ let AggregateFunctionParams {
369+ args,
370+ distinct,
371+ filter,
372+ order_by,
373+ null_treatment,
374+ } = & agg. params ;
375+
376+ if * distinct
377+ || filter. is_some ( )
378+ || !order_by. is_empty ( )
379+ || null_treatment. is_some ( )
380+ || args. len ( ) != 1
381+ {
382+ return Ok ( Expr :: AggregateFunction ( agg) ) ;
383+ }
384+
385+ // otherwise check the arguments if they are <arg> + <literal>
386+ let ( arg, lit) = match SplitResult :: new ( agg. params . args [ 0 ] . clone ( ) ) {
387+ SplitResult :: Original => return Ok ( Expr :: AggregateFunction ( agg) ) ,
388+ SplitResult :: Split { arg, lit } => ( arg, lit) ,
389+ } ;
390+
391+ if !has_common_rewrite_arg ( & arg, info) {
392+ return Ok ( Expr :: AggregateFunction ( agg) ) ;
393+ }
394+
395+ let lit_type = match & lit {
396+ Expr :: Literal ( value, _) => value. data_type ( ) ,
397+ _ => unreachable ! ( "SplitResult::Split guarantees literal side" ) ,
398+ } ;
399+ if lit_type == DataType :: Null {
400+ return Ok ( Expr :: AggregateFunction ( agg) ) ;
401+ }
402+
403+ // Rewrite to SUM(arg)
404+ agg. params . args = vec ! [ arg. clone( ) ] ;
405+ let sum_agg = Expr :: AggregateFunction ( agg) ;
406+
407+ let count_agg = cast ( crate :: count:: count ( arg) , lit_type) ;
408+
409+ // sum(arg) + scalar * COUNT(arg)
410+ Ok ( sum_agg + ( lit * count_agg) )
411+ }
412+
413+ fn has_common_rewrite_arg ( arg : & Expr , info : & SimplifyContext ) -> bool {
414+ let Some ( aggregate_exprs) = info. aggregate_exprs ( ) else {
415+ // Only apply this rewrite in the context of an Aggregate node where
416+ // sibling aggregate expressions are known.
417+ return false ;
418+ } ;
419+
420+ aggregate_exprs
421+ . iter ( )
422+ . filter_map ( sum_rewrite_candidate_arg)
423+ . filter ( |candidate_arg| candidate_arg == arg)
424+ . take ( 2 )
425+ . count ( )
426+ > 1
427+ }
428+
429+ fn sum_rewrite_candidate_arg ( expr : & Expr ) -> Option < Expr > {
430+ let Expr :: AggregateFunction ( aggregate_fn) = expr. clone ( ) . unalias_nested ( ) . data else {
431+ return None ;
432+ } ;
433+ if !aggregate_fn. func . name ( ) . eq_ignore_ascii_case ( "sum" ) {
434+ return None ;
435+ }
436+
437+ let AggregateFunctionParams {
438+ args,
439+ distinct,
440+ filter,
441+ order_by,
442+ null_treatment,
443+ } = & aggregate_fn. params ;
444+
445+ if * distinct
446+ || filter. is_some ( )
447+ || !order_by. is_empty ( )
448+ || null_treatment. is_some ( )
449+ || args. len ( ) != 1
450+ {
451+ return None ;
452+ }
453+
454+ match SplitResult :: new ( args[ 0 ] . clone ( ) ) {
455+ SplitResult :: Split { arg, .. } => Some ( arg) ,
456+ SplitResult :: Original => None ,
457+ }
458+ }
459+
460+ /// Result of trying to split an expression into an arg and constant
461+ #[ expect( clippy:: large_enum_variant) ]
462+ #[ derive( Debug , Clone ) ]
463+ enum SplitResult {
464+ /// if the expression is either of
465+ /// * `<arg> <op> <lit>`
466+ /// * `<lit> <op> <arg>`
467+ ///
468+ /// When `op` is `+`
469+ Split { arg : Expr , lit : Expr } ,
470+ /// If the expression is something else
471+ Original ,
472+ }
473+
474+ impl SplitResult {
475+ fn new ( expr : Expr ) -> Self {
476+ let Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) = expr else {
477+ return Self :: Original ;
478+ } ;
479+ if op != Operator :: Plus {
480+ return Self :: Original ;
481+ }
482+
483+ match ( left. as_ref ( ) , right. as_ref ( ) ) {
484+ ( Expr :: Literal ( ..) , _) => Self :: Split {
485+ arg : * right,
486+ lit : * left,
487+ } ,
488+ ( _, Expr :: Literal ( ..) ) => Self :: Split {
489+ arg : * left,
490+ lit : * right,
491+ } ,
492+ _ => Self :: Original ,
493+ }
494+ }
349495}
350496
351497/// This accumulator computes SUM incrementally
0 commit comments