Skip to content

Commit 25350af

Browse files
committed
Add rewrite SUM(expr+C) --> SUM(expr) + COUNT(expr)*C
1 parent 4310ec8 commit 25350af

6 files changed

Lines changed: 353 additions & 50 deletions

File tree

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ impl Alias {
600600
}
601601
}
602602

603-
/// Binary expression
603+
/// Binary expression for [`Expr::BinaryExpr`]
604604
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
605605
pub struct BinaryExpr {
606606
/// Left-hand side of the expression

datafusion/expr/src/simplify.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub struct SimplifyContext {
3838
schema: DFSchemaRef,
3939
query_execution_start_time: Option<DateTime<Utc>>,
4040
config_options: Arc<ConfigOptions>,
41+
aggregate_exprs: Option<Arc<Vec<Expr>>>,
4142
}
4243

4344
impl Default for SimplifyContext {
@@ -46,6 +47,7 @@ impl Default for SimplifyContext {
4647
schema: Arc::new(DFSchema::empty()),
4748
query_execution_start_time: None,
4849
config_options: Arc::new(ConfigOptions::default()),
50+
aggregate_exprs: None,
4951
}
5052
}
5153
}
@@ -78,6 +80,12 @@ impl SimplifyContext {
7880
self
7981
}
8082

83+
/// Set aggregate expressions from the containing aggregate node, if any.
84+
pub fn with_aggregate_exprs(mut self, aggregate_exprs: Arc<Vec<Expr>>) -> Self {
85+
self.aggregate_exprs = Some(aggregate_exprs);
86+
self
87+
}
88+
8189
/// Returns the schema
8290
pub fn schema(&self) -> &DFSchemaRef {
8391
&self.schema
@@ -108,6 +116,11 @@ impl SimplifyContext {
108116
pub fn config_options(&self) -> &Arc<ConfigOptions> {
109117
&self.config_options
110118
}
119+
120+
/// Returns aggregate expressions from the containing aggregate node, if any.
121+
pub fn aggregate_exprs(&self) -> Option<&[Expr]> {
122+
self.aggregate_exprs.as_deref().map(Vec::as_slice)
123+
}
111124
}
112125

113126
/// Was the expression simplified?

datafusion/functions-aggregate/src/sum.rs

Lines changed: 151 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,17 @@ use datafusion_common::types::{
3232
logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
3333
};
3434
use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
35-
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
35+
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
36+
use datafusion_expr::expr_fn::cast;
37+
use datafusion_expr::function::{
38+
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
39+
};
40+
use datafusion_expr::simplify::SimplifyContext;
3641
use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
3742
use datafusion_expr::{
38-
Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
39-
ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass,
40-
Volatility,
43+
Accumulator, AggregateUDFImpl, BinaryExpr, Coercion, Documentation, Expr,
44+
GroupsAccumulator, Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature,
45+
TypeSignatureClass, Volatility,
4146
};
4247
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
4348
use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
@@ -54,7 +59,7 @@ make_udaf_expr_and_func!(
5459
);
5560

5661
pub fn sum_distinct(expr: Expr) -> Expr {
57-
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
62+
Expr::AggregateFunction(AggregateFunction::new_udf(
5863
sum_udaf(),
5964
vec![expr],
6065
true,
@@ -346,6 +351,147 @@ impl AggregateUDFImpl for Sum {
346351
_ => SetMonotonicity::NotMonotonic,
347352
}
348353
}
354+
355+
/// Simplification Rules
356+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
357+
Some(Box::new(sum_simplifier))
358+
}
359+
}
360+
361+
/// Implement ClickBench Q29 specific optimization:
362+
/// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
363+
///
364+
/// Backstory: TODO
365+
///
366+
fn sum_simplifier(mut agg: AggregateFunction, info: &SimplifyContext) -> Result<Expr> {
367+
// Explicitly destructure to ensure we check all relevant fields
368+
let AggregateFunctionParams {
369+
args,
370+
distinct,
371+
filter,
372+
order_by,
373+
null_treatment,
374+
} = &agg.params;
375+
376+
if *distinct
377+
|| filter.is_some()
378+
|| !order_by.is_empty()
379+
|| null_treatment.is_some()
380+
|| args.len() != 1
381+
{
382+
return Ok(Expr::AggregateFunction(agg));
383+
}
384+
385+
// otherwise check the arguments if they are <arg> + <literal>
386+
let (arg, lit) = match SplitResult::new(agg.params.args[0].clone()) {
387+
SplitResult::Original => return Ok(Expr::AggregateFunction(agg)),
388+
SplitResult::Split { arg, lit } => (arg, lit),
389+
};
390+
391+
if !has_common_rewrite_arg(&arg, info) {
392+
return Ok(Expr::AggregateFunction(agg));
393+
}
394+
395+
let lit_type = match &lit {
396+
Expr::Literal(value, _) => value.data_type(),
397+
_ => unreachable!("SplitResult::Split guarantees literal side"),
398+
};
399+
if lit_type == DataType::Null {
400+
return Ok(Expr::AggregateFunction(agg));
401+
}
402+
403+
// Rewrite to SUM(arg)
404+
agg.params.args = vec![arg.clone()];
405+
let sum_agg = Expr::AggregateFunction(agg);
406+
407+
let count_agg = cast(crate::count::count(arg), lit_type);
408+
409+
// sum(arg) + scalar * COUNT(arg)
410+
Ok(sum_agg + (lit * count_agg))
411+
}
412+
413+
fn has_common_rewrite_arg(arg: &Expr, info: &SimplifyContext) -> bool {
414+
let Some(aggregate_exprs) = info.aggregate_exprs() else {
415+
// Only apply this rewrite in the context of an Aggregate node where
416+
// sibling aggregate expressions are known.
417+
return false;
418+
};
419+
420+
aggregate_exprs
421+
.iter()
422+
.filter_map(sum_rewrite_candidate_arg)
423+
.filter(|candidate_arg| candidate_arg == arg)
424+
.take(2)
425+
.count()
426+
> 1
427+
}
428+
429+
fn sum_rewrite_candidate_arg(expr: &Expr) -> Option<Expr> {
430+
let Expr::AggregateFunction(aggregate_fn) = expr.clone().unalias_nested().data else {
431+
return None;
432+
};
433+
if !aggregate_fn.func.name().eq_ignore_ascii_case("sum") {
434+
return None;
435+
}
436+
437+
let AggregateFunctionParams {
438+
args,
439+
distinct,
440+
filter,
441+
order_by,
442+
null_treatment,
443+
} = &aggregate_fn.params;
444+
445+
if *distinct
446+
|| filter.is_some()
447+
|| !order_by.is_empty()
448+
|| null_treatment.is_some()
449+
|| args.len() != 1
450+
{
451+
return None;
452+
}
453+
454+
match SplitResult::new(args[0].clone()) {
455+
SplitResult::Split { arg, .. } => Some(arg),
456+
SplitResult::Original => None,
457+
}
458+
}
459+
460+
/// Result of trying to split an expression into an arg and constant
461+
#[expect(clippy::large_enum_variant)]
462+
#[derive(Debug, Clone)]
463+
enum SplitResult {
464+
/// if the expression is either of
465+
/// * `<arg> <op> <lit>`
466+
/// * `<lit> <op> <arg>`
467+
///
468+
/// When `op` is `+`
469+
Split { arg: Expr, lit: Expr },
470+
/// If the expression is something else
471+
Original,
472+
}
473+
474+
impl SplitResult {
475+
fn new(expr: Expr) -> Self {
476+
let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
477+
return Self::Original;
478+
};
479+
if op != Operator::Plus {
480+
return Self::Original;
481+
}
482+
483+
match (left.as_ref(), right.as_ref()) {
484+
(Expr::Literal(..), _) => Self::Split {
485+
arg: *right,
486+
lit: *left,
487+
},
488+
(_, Expr::Literal(..)) => Self::Split {
489+
arg: *left,
490+
lit: *right,
491+
},
492+
_ => Self::Original,
493+
}
494+
}
349495
}
350496

351497
/// This accumulator computes SUM incrementally

0 commit comments

Comments
 (0)