
Commit 6d3a846

Rewrite SUM(expr + scalar) --> SUM(expr) + scalar*COUNT(expr) (#20749)
## Which issue does this PR close?

- Part of #18489
- Closes #20180
- Closes #15524
- Replaces #20665

## Rationale for this change

I [want DataFusion to be the fastest parquet engine on ClickBench](#18489). One of the queries where DataFusion is significantly slower is Query 29 which has a very strange pattern of many aggregate functions that are offset by a constant:

https://github.com/apache/datafusion/blob/0ca9d6586a43c323525b2e299448e0f1af4d6195/benchmarks/queries/clickbench/queries/q29.sql#L4

This is not a pattern I have ever seen in a real query, but it seems like the engine currently at the top of the ClickBench leaderboard has a special case for this pattern. ClickHouse probably does too. See

- duckdb/duckdb#15017
- Discussion on #15524

Thus I reluctantly conclude that we should have one too.

## What changes are included in this PR?

This is an alternate to my first attempt.

- #20665

In particular, since this is such a ClickBench specific rule, I wanted to

1. Minimize the downstream API / upgrade impact (aka not change existing APIs)
2. Optimize performance for the case where this rewrite will not apply (most times)

1. Add a rewrite `SUM(expr + scalar)` --> `SUM(expr) + scalar*COUNT(expr)`
3. Tests for same

Note there are quite a few other ideas to potentially make this more general on #15524 but I am going with the simple thing of making it work for the usecase we have in hand (ClickBench)

## Are these changes tested?

Yes, new tests are added

## Are there any user-facing changes?

Faster performance 🚀

```
│ QQuery 29 │ 1012.63 ms │ 139.02 ms │ +7.28x faster │
```
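The rewrite described above rests on the identity `SUM(x + c) = SUM(x) + c * COUNT(x)`, where `COUNT(x)` counts only non-NULL rows. A minimal sketch in plain Rust checks it on a nullable column. The helpers `sum`, `count`, `sum_of_shifted`, and `rewritten` are illustrative names, not DataFusion APIs, and integer overflow is ignored here.

```rust
/// SQL-style SUM over nullable values: NULLs are skipped, and the result is
/// NULL (`None`) when there is no non-NULL input.
fn sum(values: &[Option<i64>]) -> Option<i64> {
    let mut acc: Option<i64> = None;
    for v in values.iter().flatten() {
        acc = Some(acc.unwrap_or(0) + v);
    }
    acc
}

/// SQL-style COUNT(expr): counts only non-NULL values.
fn count(values: &[Option<i64>]) -> i64 {
    values.iter().flatten().count() as i64
}

/// Direct evaluation of SUM(x + c). NULL + c is NULL, so NULL rows stay NULL.
fn sum_of_shifted(values: &[Option<i64>], c: i64) -> Option<i64> {
    let shifted: Vec<Option<i64>> =
        values.iter().map(|v| v.map(|x| x + c)).collect();
    sum(&shifted)
}

/// Rewritten form: SUM(x) + c * COUNT(x).
fn rewritten(values: &[Option<i64>], c: i64) -> Option<i64> {
    sum(values).map(|s| s + c * count(values))
}

fn main() {
    let column = [Some(10), None, Some(-3), Some(7)];
    for c in 0..5 {
        assert_eq!(sum_of_shifted(&column, c), rewritten(&column, c));
    }
    // An all-NULL input yields NULL on both sides.
    assert_eq!(sum_of_shifted(&[None, None], 1), rewritten(&[None, None], 1));
    println!("SUM(x + c) == SUM(x) + c * COUNT(x) on the sample data");
}
```

Using `COUNT(x)` rather than `COUNT(*)` matters in this sketch: rows where `x` is NULL contribute nothing to either side.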
1 parent 5db04b8 commit 6d3a846

9 files changed

Lines changed: 574 additions & 44 deletions

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 1 deletion
@@ -600,7 +600,7 @@ impl Alias {
     }
 }

-/// Binary expression
+/// Binary expression for [`Expr::BinaryExpr`]
 #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
 pub struct BinaryExpr {
     /// Left-hand side of the expression

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 3 additions & 1 deletion
@@ -3490,7 +3490,9 @@ pub struct Aggregate {
     pub input: Arc<LogicalPlan>,
     /// Grouping expressions
     pub group_expr: Vec<Expr>,
-    /// Aggregate expressions
+    /// Aggregate expressions.
+    ///
+    /// Note these *must* be either [`Expr::AggregateFunction`] or [`Expr::Alias`]
     pub aggr_expr: Vec<Expr>,
     /// The schema description of the aggregate output
     pub schema: DFSchemaRef,

datafusion/expr/src/udaf.rs

Lines changed: 96 additions & 0 deletions
@@ -28,6 +28,7 @@ use arrow::datatypes::{DataType, Field, FieldRef};

 use datafusion_common::{Result, ScalarValue, Statistics, exec_err, not_impl_err};
 use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
+use datafusion_expr_common::operator::Operator;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

 use crate::expr::{
@@ -301,6 +302,21 @@ impl AggregateUDF {
         self.inner.simplify()
     }

+    /// Rewrite aggregate to have simpler arguments
+    ///
+    /// See [`AggregateUDFImpl::simplify_expr_op_literal`] for more details
+    pub fn simplify_expr_op_literal(
+        &self,
+        agg_function: &AggregateFunction,
+        arg: &Expr,
+        op: Operator,
+        lit: &Expr,
+        arg_is_left: bool,
+    ) -> Result<Option<Expr>> {
+        self.inner
+            .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
+    }
+
     /// Returns true if the function is max, false if the function is min
     /// None in all other cases, used in certain optimizations for
     /// or aggregate
@@ -691,6 +707,74 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
         None
     }

+    /// Rewrite the aggregate to have simpler arguments
+    ///
+    /// This query pattern is not common in most real workloads, and most
+    /// aggregate implementations can safely ignore it. This API is included in
+    /// DataFusion because it is important for ClickBench Q29. See backstory
+    /// on <https://github.com/apache/datafusion/issues/15524>
+    ///
+    /// # Rewrite Overview
+    ///
+    /// The idea is to rewrite multiple aggregates with "complex arguments" into
+    /// ones with simpler arguments that can be optimized by common subexpression
+    /// elimination (CSE). At a high level the rewrite looks like
+    ///
+    /// * `Aggregate(SUM(x + 1), SUM(x + 2), ...)`
+    ///
+    /// Into
+    ///
+    /// * `Aggregate(SUM(x) + 1 * COUNT(x), SUM(x) + 2 * COUNT(x), ...)`
+    ///
+    /// While this rewrite may seem worse (slower) than the original as it
+    /// computes *more* aggregate expressions, the common subexpression
+    /// elimination (CSE) can then reduce the number of distinct aggregates the
+    /// query actually needs to compute with a rewrite like
+    ///
+    /// * `Projection(_A + 1*_B, _A + 2*_B)`
+    /// * ` Aggregate(_A = SUM(x), _B = COUNT(x))`
+    ///
+    /// This optimization is extremely important for ClickBench Q29, which has 90
+    /// such expressions for some reason, and so this optimization results in
+    /// only two aggregates being needed. The DataFusion optimizer will invoke
+    /// this method when it detects multiple aggregates in a query that share
+    /// arguments of the form `<arg> <op> <literal>`.
+    ///
+    /// # API
+    ///
+    /// If `agg_function` supports the rewrite, it should return a semantically
+    /// equivalent expression (likely with more aggregate expressions, but
+    /// simpler arguments)
+    ///
+    /// This is only called when:
+    /// 1. There are no "special" aggregate params (filters, null handling, etc)
+    /// 2. Aggregate functions with exactly one [`Expr`] argument
+    /// 3. There are no volatile expressions
+    ///
+    /// Arguments
+    /// * `agg_function`: the original aggregate function detected with complex
+    /// arguments.
+    /// * `arg`: The common argument shared across multiple aggregates (e.g. `x`
+    /// in the example above)
+    /// * `op`: the operator between the common argument and the literal (e.g.
+    /// `+` in `x + 1` or `1 + x`)
+    /// * `lit`: the literal argument (e.g. `1` or `2` in the example above)
+    /// * `arg_is_left`: whether the common argument is on the left or right of
+    /// the operator (e.g. `true` for `x + 1` and false for `1 + x`)
+    ///
+    /// The default implementation returns `None`, which is what most aggregates
+    /// should do.
+    fn simplify_expr_op_literal(
+        &self,
+        _agg_function: &AggregateFunction,
+        _arg: &Expr,
+        _op: Operator,
+        _lit: &Expr,
+        _arg_is_left: bool,
+    ) -> Result<Option<Expr>> {
+        Ok(None)
+    }
+
     /// Returns the reverse expression of the aggregate function.
     fn reverse_expr(&self) -> ReversedUDAF {
         ReversedUDAF::NotSupported
@@ -1243,6 +1327,18 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
         self.inner.simplify()
     }

+    fn simplify_expr_op_literal(
+        &self,
+        agg_function: &AggregateFunction,
+        arg: &Expr,
+        op: Operator,
+        lit: &Expr,
+        arg_is_left: bool,
+    ) -> Result<Option<Expr>> {
+        self.inner
+            .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
+    }
+
     fn reverse_expr(&self) -> ReversedUDAF {
         self.inner.reverse_expr()
     }
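The doc comment on `AggregateUDFImpl::simplify_expr_op_literal` argues that the rewrite pays off only because common subexpression elimination then collapses the repeated `SUM(x)` and `COUNT(x)`. The toy model below illustrates that effect. It is a sketch under stated assumptions: `Toy`, `rewrite_sum`, `collect_aggs`, and `distinct_aggs_after_rewrite` are hypothetical stand-ins written for this example and do not use DataFusion's `Expr` type or its optimizer.

```rust
use std::collections::BTreeSet;

/// Toy expression tree; a hypothetical stand-in for DataFusion's `Expr`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Toy {
    Col(String),
    Lit(i64),
    Add(Box<Toy>, Box<Toy>),
    Mul(Box<Toy>, Box<Toy>),
    Sum(Box<Toy>),
    Count(Box<Toy>),
}

/// Mirrors the shape of the hook: `Some(rewritten)` for `SUM(arg + lit)` or
/// `SUM(lit + arg)`, and `None` for anything else.
fn rewrite_sum(agg: &Toy) -> Option<Toy> {
    let Toy::Sum(inner) = agg else { return None };
    let Toy::Add(l, r) = &**inner else { return None };
    let (arg, lit) = match (&**l, &**r) {
        (Toy::Lit(_), Toy::Lit(_)) => return None,
        (arg, Toy::Lit(c)) => (arg.clone(), *c),
        (Toy::Lit(c), arg) => (arg.clone(), *c),
        _ => return None,
    };
    // SUM(arg) + lit * COUNT(arg)
    let sum = Toy::Sum(Box::new(arg.clone()));
    let count = Toy::Count(Box::new(arg));
    let scaled = Toy::Mul(Box::new(Toy::Lit(lit)), Box::new(count));
    Some(Toy::Add(Box::new(sum), Box::new(scaled)))
}

/// Collects the distinct aggregate calls in an expression, which is the set a
/// CSE pass would need to compute once each.
fn collect_aggs(e: &Toy, out: &mut BTreeSet<Toy>) {
    match e {
        Toy::Sum(_) | Toy::Count(_) => {
            out.insert(e.clone());
        }
        Toy::Add(l, r) | Toy::Mul(l, r) => {
            collect_aggs(l, out);
            collect_aggs(r, out);
        }
        Toy::Col(_) | Toy::Lit(_) => {}
    }
}

/// Builds SUM(x + 0), SUM(x + 1), ..., rewrites each one, and reports how many
/// distinct aggregates remain.
fn distinct_aggs_after_rewrite(n: i64) -> usize {
    let x = Toy::Col("x".to_string());
    let mut distinct = BTreeSet::new();
    for k in 0..n {
        let shifted = Toy::Add(Box::new(x.clone()), Box::new(Toy::Lit(k)));
        let agg = Toy::Sum(Box::new(shifted));
        let rewritten = rewrite_sum(&agg).unwrap_or(agg);
        collect_aggs(&rewritten, &mut distinct);
    }
    distinct.len()
}

fn main() {
    // 90 mirrors the number of such expressions the doc comment cites for Q29.
    println!("{}", distinct_aggs_after_rewrite(90)); // prints 2
}
```

Without the rewrite, each `SUM(x + k)` is a different aggregate, so the same loop would collect 90 distinct entries instead of 2.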

datafusion/functions-aggregate/src/sum.rs

Lines changed: 47 additions & 3 deletions
@@ -27,17 +27,20 @@ use arrow::datatypes::{
     DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef,
     Float64Type, Int64Type, TimeUnit, UInt64Type,
 };
+use datafusion_common::internal_err;
 use datafusion_common::types::{
     NativeType, logical_float64, logical_int8, logical_int16, logical_int32,
     logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
 };
 use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::expr_fn::cast;
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
 use datafusion_expr::{
     Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
-    ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass,
-    Volatility,
+    Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature,
+    TypeSignatureClass, Volatility,
 };
 use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
 use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
@@ -54,7 +57,7 @@ make_udaf_expr_and_func!(
 );

 pub fn sum_distinct(expr: Expr) -> Expr {
-    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+    Expr::AggregateFunction(AggregateFunction::new_udf(
         sum_udaf(),
         vec![expr],
         true,
@@ -346,6 +349,47 @@ impl AggregateUDFImpl for Sum {
             _ => SetMonotonicity::NotMonotonic,
         }
     }
+
+    /// Implement ClickBench Q29 specific optimization:
+    /// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
+    ///
+    /// See background on [`AggregateUDFImpl::simplify_expr_op_literal`]
+    fn simplify_expr_op_literal(
+        &self,
+        agg_function: &AggregateFunction,
+        arg: &Expr,
+        op: Operator,
+        lit: &Expr,
+        // Only support '+' so the order of the args doesn't matter
+        _arg_is_left: bool,
+    ) -> Result<Option<Expr>> {
+        if op != Operator::Plus {
+            return Ok(None);
+        }
+
+        let lit_type = match &lit {
+            Expr::Literal(value, _) => value.data_type(),
+            _ => {
+                return internal_err!(
+                    "Sum::simplify_expr_op_literal got a non literal argument"
+                );
+            }
+        };
+        if lit_type == DataType::Null {
+            return Ok(None);
+        }
+
+        // Build up SUM(arg)
+        let mut sum_agg = agg_function.clone();
+        sum_agg.params.args = vec![arg.clone()];
+        let sum_agg = Expr::AggregateFunction(sum_agg);
+
+        // COUNT(arg) - cast to the correct type
+        let count_agg = cast(crate::count::count(arg.clone()), lit_type);
+
+        // SUM(arg) + lit * COUNT(arg)
+        Ok(Some(sum_agg + (lit.clone() * count_agg)))
+    }
 }

 /// This accumulator computes SUM incrementally
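The `Sum` implementation above handles only `Operator::Plus`, which is why it can ignore `arg_is_left`. The sketch below illustrates why the flag would matter for a non-commutative operator such as subtraction. This extension is not part of the commit: `direct` and `rewritten` are hypothetical helpers over plain non-NULL `i64` slices, and overflow is ignored.

```rust
/// Direct evaluation: SUM(x - c) when `arg_is_left`, otherwise SUM(c - x).
fn direct(xs: &[i64], c: i64, arg_is_left: bool) -> i64 {
    xs.iter()
        .map(|&x| if arg_is_left { x - c } else { c - x })
        .sum()
}

/// Hypothetical rewritten form built only from SUM(x) and COUNT(x).
/// The operand order decides which identity applies:
///   SUM(x - c) = SUM(x) - c * COUNT(x)
///   SUM(c - x) = c * COUNT(x) - SUM(x)
fn rewritten(xs: &[i64], c: i64, arg_is_left: bool) -> i64 {
    let sum: i64 = xs.iter().sum();
    let count = xs.len() as i64;
    if arg_is_left {
        sum - c * count
    } else {
        c * count - sum
    }
}

fn main() {
    let xs = [4, 9, -2];
    for arg_is_left in [true, false] {
        assert_eq!(direct(&xs, 3, arg_is_left), rewritten(&xs, 3, arg_is_left));
    }
    println!("both operand orders agree with their rewritten forms");
}
```

For addition the two branches coincide, so a `+`-only implementation loses nothing by dropping the flag.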
