Skip to content

Commit dadf892

Browse files
committed
Add rewrite SUM(expr+C) --> SUM(expr) + COUNT(expr)*C
1 parent 848cd63 commit dadf892

2 files changed

Lines changed: 92 additions & 6 deletions

File tree

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ impl Alias {
600600
}
601601
}
602602

603-
/// Binary expression
603+
/// Binary expression for [`Expr::BinaryExpr`]
604604
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
605605
pub struct BinaryExpr {
606606
/// Left-hand side of the expression

datafusion/functions-aggregate/src/sum.rs

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,16 @@ use datafusion_common::types::{
3232
logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
3333
};
3434
use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
35-
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
35+
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
36+
use datafusion_expr::function::{
37+
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
38+
};
39+
use datafusion_expr::simplify::SimplifyContext;
3640
use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
3741
use datafusion_expr::{
38-
Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
39-
ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass,
40-
Volatility,
42+
Accumulator, AggregateUDFImpl, BinaryExpr, Coercion, Documentation, Expr,
43+
GroupsAccumulator, Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature,
44+
TypeSignatureClass, Volatility,
4145
};
4246
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
4347
use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
@@ -54,7 +58,7 @@ make_udaf_expr_and_func!(
5458
);
5559

5660
pub fn sum_distinct(expr: Expr) -> Expr {
57-
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
61+
Expr::AggregateFunction(AggregateFunction::new_udf(
5862
sum_udaf(),
5963
vec![expr],
6064
true,
@@ -346,6 +350,88 @@ impl AggregateUDFImpl for Sum {
346350
_ => SetMonotonicity::NotMonotonic,
347351
}
348352
}
353+
354+
/// Simplification Rules
355+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
356+
Some(Box::new(sum_simplifier))
357+
}
358+
}
359+
360+
/// Implement ClickBench Q29 specific optimization:
361+
/// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
362+
///
363+
/// Backstory: TODO
364+
///
365+
fn sum_simplifier(mut agg: AggregateFunction, _info: &SimplifyContext) -> Result<Expr> {
366+
// Explicitly destructure to ensure we check all relevant fields
367+
let AggregateFunctionParams {
368+
args,
369+
distinct,
370+
filter,
371+
order_by,
372+
null_treatment,
373+
} = &agg.params;
374+
375+
if *distinct
376+
|| filter.is_some()
377+
|| !order_by.is_empty()
378+
|| null_treatment.is_some()
379+
|| args.len() != 1
380+
{
381+
return Ok(Expr::AggregateFunction(agg));
382+
}
383+
384+
// otherwise check the arguments if they are <col> <op> scalar
385+
let (arg, lit) = match SplitResult::new(agg.params.args.swap_remove(0)) {
386+
SplitResult::Original(expr) => {
387+
agg.params.args.push(expr); // put it back
388+
return Ok(Expr::AggregateFunction(agg));
389+
}
390+
SplitResult::Split { arg, lit } => (arg, lit),
391+
};
392+
393+
// Rewrite to SUM(arg)
394+
agg.params.args.push(arg.clone());
395+
let sum_agg = Expr::AggregateFunction(agg);
396+
397+
// sum(arg) + scalar * COUNT(arg)
398+
Ok(sum_agg + (lit * crate::count::count(arg)))
399+
}
400+
401+
/// Result of trying to split an expression into an arg and constant
402+
#[derive(Debug, Clone)]
403+
enum SplitResult {
404+
/// if the expression is either of
405+
/// * `<arg> <op> <lit>`
406+
/// * `<lit> <op> <arg>`
407+
///
408+
/// When `op` is `+`
409+
Split { arg: Expr, lit: Expr },
410+
/// If the expression is something else
411+
Original(Expr),
412+
}
413+
414+
impl SplitResult {
415+
fn new(expr: Expr) -> Self {
416+
let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
417+
return Self::Original(expr);
418+
};
419+
if op != Operator::Plus {
420+
return Self::Original(Expr::BinaryExpr(BinaryExpr { left, op, right }));
421+
}
422+
423+
match (left.as_ref(), right.as_ref()) {
424+
(Expr::Literal(..), _) => Self::Split {
425+
arg: *right,
426+
lit: *left,
427+
},
428+
(_, Expr::Literal(..)) => Self::Split {
429+
arg: *left,
430+
lit: *right,
431+
},
432+
_ => Self::Original(Expr::BinaryExpr(BinaryExpr { left, op, right })),
433+
}
434+
}
349435
}
350436

351437
/// This accumulator computes SUM incrementally

0 commit comments

Comments
 (0)