Skip to content

Commit 68cff5e

Browse files
committed
Allow AggregateUDFImpl::simplify to return non AggregateExpr
1 parent 0cdba08 commit 68cff5e

2 files changed

Lines changed: 163 additions & 64 deletions

File tree

datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
use std::sync::Arc;
2121

2222
use datafusion_common::tree_node::{Transformed, TreeNode};
23-
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
23+
use datafusion_common::{Column, DFSchema, DFSchemaRef, DataFusionError, Result};
2424
use datafusion_expr::Expr;
25-
use datafusion_expr::logical_plan::LogicalPlan;
25+
use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection};
2626
use datafusion_expr::simplify::SimplifyContext;
27-
use datafusion_expr::utils::merge_schema;
27+
use datafusion_expr::utils::{
28+
columnize_expr, find_aggregate_exprs, grouping_set_to_exprlist, merge_schema,
29+
};
2830

2931
use crate::optimizer::ApplyOrder;
3032
use crate::utils::NamePreserver;
@@ -130,13 +132,17 @@ impl SimplifyExpressions {
130132
))
131133
};
132134

133-
plan.map_expressions(|expr| {
135+
let transformed = plan.map_expressions(|expr| {
134136
// Preserve the aliasing of grouping sets.
135137
if let Expr::GroupingSet(_) = &expr {
136138
expr.map_children(&mut rewrite_expr)
137139
} else {
138140
rewrite_expr(expr)
139141
}
142+
})?;
143+
144+
transformed.transform_data(|plan| {
145+
rewrite_aggregate_non_aggregate_aggr_expr(plan)
140146
})
141147
}
142148
}
@@ -148,6 +154,73 @@ impl SimplifyExpressions {
148154
}
149155
}
150156

157+
/// Rewrites
158+
/// `Aggregate(group_expr, aggr_expr=[non_agg_expr(sum(..), count(..), ..)])`
159+
/// into:
160+
/// `Projection(..)` over `Aggregate(group_expr, aggr_expr=[sum(..), count(..), ..])`.
161+
///
162+
/// Aggregate planning requires each aggregate slot to be an aggregate function
163+
/// (possibly aliased). Some UDAF simplifiers can return larger expressions that
164+
/// reference multiple aggregate functions.
165+
fn rewrite_aggregate_non_aggregate_aggr_expr(
166+
plan: LogicalPlan,
167+
) -> Result<Transformed<LogicalPlan>> {
168+
let LogicalPlan::Aggregate(Aggregate {
169+
input,
170+
group_expr,
171+
aggr_expr,
172+
..
173+
}) = plan
174+
else {
175+
return Ok(Transformed::no(plan));
176+
};
177+
178+
if aggr_expr.iter().all(is_top_level_aggregate_expr) {
179+
return Ok(Transformed::no(LogicalPlan::Aggregate(Aggregate::try_new(
180+
input, group_expr, aggr_expr,
181+
)?)));
182+
}
183+
184+
let inner_aggr_expr = find_aggregate_exprs(aggr_expr.iter());
185+
let inner_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
186+
Arc::clone(&input),
187+
group_expr.clone(),
188+
inner_aggr_expr,
189+
)?);
190+
let inner_aggregate = Arc::new(inner_aggregate);
191+
192+
let mut projection_exprs = aggregate_output_exprs(&group_expr)?;
193+
projection_exprs.extend(aggr_expr);
194+
let projection_exprs = projection_exprs
195+
.into_iter()
196+
.map(|expr| columnize_expr(expr, inner_aggregate.as_ref()))
197+
.collect::<Result<Vec<_>>>()?;
198+
199+
Ok(Transformed::yes(LogicalPlan::Projection(Projection::try_new(
200+
projection_exprs,
201+
inner_aggregate,
202+
)?)))
203+
}
204+
205+
fn is_top_level_aggregate_expr(expr: &Expr) -> bool {
206+
matches!(expr.clone().unalias_nested().data, Expr::AggregateFunction(_))
207+
}
208+
209+
fn aggregate_output_exprs(group_expr: &[Expr]) -> Result<Vec<Expr>> {
210+
let mut output_exprs = grouping_set_to_exprlist(group_expr)?
211+
.into_iter()
212+
.cloned()
213+
.collect::<Vec<_>>();
214+
215+
if matches!(group_expr, [Expr::GroupingSet(_)]) {
216+
output_exprs.push(Expr::Column(Column::from_name(
217+
Aggregate::INTERNAL_GROUPING_ID,
218+
)));
219+
}
220+
221+
Ok(output_exprs)
222+
}
223+
151224
#[cfg(test)]
152225
mod tests {
153226
use std::ops::Not;
@@ -158,7 +231,7 @@ mod tests {
158231
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
159232
use datafusion_expr::logical_plan::table_scan;
160233
use datafusion_expr::*;
161-
use datafusion_functions_aggregate::expr_fn::{max, min};
234+
use datafusion_functions_aggregate::expr_fn::{max, min, sum};
162235

163236
use crate::OptimizerContext;
164237
use crate::assert_optimized_plan_eq_snapshot;
@@ -258,6 +331,28 @@ mod tests {
258331
)
259332
}
260333

334+
// Builds an ungrouped aggregate over `sum(a + 2)` and checks that the optimized
// plan is a `Projection` over an `Aggregate` computing `sum(test.a)` and
// `count(test.a)`, with the projected expression aliased back to the original
// `sum(test.a + Int64(2))` name.
#[test]
335+
fn test_simplify_udaf_to_non_aggregate_expr() -> Result<()> {
336+
let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
337+
let table_scan = table_scan(Some("test"), &schema, None)?
338+
.build()
339+
.expect("building scan");
340+
341+
let plan = LogicalPlanBuilder::from(table_scan)
342+
.aggregate(Vec::<Expr>::new(), vec![sum(col("a") + lit(2i64))])?
343+
.build()?;
344+
345+
assert_optimized_plan_equal!(
346+
plan,
347+
@r"
348+
Projection: sum(test.a) + Int64(2) * count(test.a) AS sum(test.a + Int64(2))
349+
Aggregate: groupBy=[[]], aggr=[[sum(test.a), count(test.a)]]
350+
TableScan: test
351+
"
352+
)?;
353+
Ok(())
354+
}
355+
261356
#[test]
262357
fn test_simplify_optimized_plan_with_or() -> Result<()> {
263358
let table_scan = test_table_scan();

0 commit comments

Comments
 (0)