Skip to content

Commit 2c89866

Browse files
committed
Add casts
1 parent c048aa0 commit 2c89866

2 files changed

Lines changed: 42 additions & 23 deletions

File tree

datafusion/functions-aggregate/src/sum.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use datafusion_common::types::{
3333
};
3434
use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
3535
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
36+
use datafusion_expr::expr_fn::cast;
3637
use datafusion_expr::function::{
3738
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
3839
};
@@ -391,12 +392,22 @@ fn sum_simplifier(mut agg: AggregateFunction, info: &SimplifyContext) -> Result<
391392
return Ok(Expr::AggregateFunction(agg));
392393
}
393394

395+
let lit_type = match &lit {
396+
Expr::Literal(value, _) => value.data_type(),
397+
_ => unreachable!("SplitResult::Split guarantees literal side"),
398+
};
399+
if lit_type == DataType::Null {
400+
return Ok(Expr::AggregateFunction(agg));
401+
}
402+
394403
// Rewrite to SUM(arg)
395404
agg.params.args = vec![arg.clone()];
396405
let sum_agg = Expr::AggregateFunction(agg);
397406

407+
let count_agg = cast(crate::count::count(arg), lit_type);
408+
398409
// sum(arg) + scalar * COUNT(arg)
399-
Ok(sum_agg + (lit * crate::count::count(arg)))
410+
Ok(sum_agg + (lit * count_agg))
400411
}
401412

402413
fn has_common_rewrite_arg(arg: &Expr, info: &SimplifyContext) -> bool {

datafusion/sqllogictest/test_files/aggregates_simplify.slt

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,10 @@ query TT
106106
EXPLAIN SELECT SUM(column1 + 1), SUM(column1 + 2) FROM sum_simplify_t;
107107
----
108108
logical_plan
109-
01)Projection: sum(sum_simplify_t.column1) + count(sum_simplify_t.column1) AS sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1) + Int64(2) * count(sum_simplify_t.column1) AS sum(sum_simplify_t.column1 + Int64(2))
110-
02)--Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
111-
03)----TableScan: sum_simplify_t projection=[column1]
109+
01)Projection: sum(sum_simplify_t.column1) + __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1) + Int64(2) * __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(2))
110+
02)--Projection: CAST(count(sum_simplify_t.column1) AS Int64) AS __common_expr_1, sum(sum_simplify_t.column1)
111+
03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
112+
04)------TableScan: sum_simplify_t projection=[column1]
112113
physical_plan
113114
01)ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1)@0 + 2 * count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(2))]
114115
02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
@@ -124,9 +125,10 @@ query TT
124125
EXPLAIN SELECT SUM(1 + column1), SUM(column1 + 2) FROM sum_simplify_t;
125126
----
126127
logical_plan
127-
01)Projection: sum(sum_simplify_t.column1) + count(sum_simplify_t.column1) AS sum(Int64(1) + sum_simplify_t.column1), sum(sum_simplify_t.column1) + Int64(2) * count(sum_simplify_t.column1) AS sum(sum_simplify_t.column1 + Int64(2))
128-
02)--Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
129-
03)----TableScan: sum_simplify_t projection=[column1]
128+
01)Projection: sum(sum_simplify_t.column1) + __common_expr_1 AS sum(Int64(1) + sum_simplify_t.column1), sum(sum_simplify_t.column1) + Int64(2) * __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(2))
129+
02)--Projection: CAST(count(sum_simplify_t.column1) AS Int64) AS __common_expr_1, sum(sum_simplify_t.column1)
130+
03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
131+
04)------TableScan: sum_simplify_t projection=[column1]
130132
physical_plan
131133
01)ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + count(sum_simplify_t.column1)@1 as sum(Int64(1) + sum_simplify_t.column1), sum(sum_simplify_t.column1)@0 + 2 * count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(2))]
132134
02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
@@ -233,20 +235,24 @@ physical_plan
233235
03)----DataSourceExec: partitions=1, partition_sizes=[1]
234236

235237
# volatile aggregate arguments
236-
query error DataFusion error: Arrow error: Invalid argument error: Invalid arithmetic operation: Float64 \* Int64
238+
query B
237239
SELECT SUM(random() + 1) < SUM(random() + 2) FROM sum_simplify_t;
240+
----
241+
true
238242

239243
query TT
240244
EXPLAIN SELECT SUM(random() + 1) < SUM(random() + 2) FROM sum_simplify_t;
241245
----
242246
logical_plan
243-
01)Projection: sum(random()) + Float64(2) * count(random()) > sum(random()) + CAST(count(random()) AS Float64) AS sum(random() + Int64(1)) < sum(random() + Int64(2))
244-
02)--Aggregate: groupBy=[[]], aggr=[[sum(random()), count(random())]]
245-
03)----TableScan: sum_simplify_t projection=[]
247+
01)Projection: sum(random()) + Float64(2) * __common_expr_1 > sum(random()) + __common_expr_1 AS sum(random() + Int64(1)) < sum(random() + Int64(2))
248+
02)--Projection: CAST(count(random()) AS Float64) AS __common_expr_1, sum(random())
249+
03)----Aggregate: groupBy=[[]], aggr=[[sum(random()), count(random())]]
250+
04)------TableScan: sum_simplify_t projection=[]
246251
physical_plan
247-
01)ProjectionExec: expr=[sum(random())@0 + 2 * count(random())@1 > sum(random())@0 + CAST(count(random())@1 AS Float64) as sum(random() + Int64(1)) < sum(random() + Int64(2))]
248-
02)--AggregateExec: mode=Single, gby=[], aggr=[sum(random()), count(random())]
249-
03)----DataSourceExec: partitions=1, partition_sizes=[1]
252+
01)ProjectionExec: expr=[sum(random())@1 + 2 * __common_expr_1@0 > sum(random())@1 + __common_expr_1@0 as sum(random() + Int64(1)) < sum(random() + Int64(2))]
253+
02)--ProjectionExec: expr=[CAST(count(random())@1 AS Float64) as __common_expr_1, sum(random())@0 as sum(random())]
254+
03)----AggregateExec: mode=Single, gby=[], aggr=[sum(random()), count(random())]
255+
04)------DataSourceExec: partitions=1, partition_sizes=[1]
250256

251257
# Checks grouped aggregates with explicit ORDER BY return deterministic row order.
252258
query III
@@ -261,9 +267,10 @@ EXPLAIN SELECT column2, SUM(column1 + 1), SUM(column1 + 2) FROM sum_simplify_t G
261267
----
262268
logical_plan
263269
01)Sort: sum_simplify_t.column2 DESC NULLS LAST
264-
02)--Projection: sum_simplify_t.column2, sum(sum_simplify_t.column1) + count(sum_simplify_t.column1) AS sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1) + Int64(2) * count(sum_simplify_t.column1) AS sum(sum_simplify_t.column1 + Int64(2))
265-
03)----Aggregate: groupBy=[[sum_simplify_t.column2]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
266-
04)------TableScan: sum_simplify_t projection=[column1, column2]
270+
02)--Projection: sum_simplify_t.column2, sum(sum_simplify_t.column1) + __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1) + Int64(2) * __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(2))
271+
03)----Projection: CAST(count(sum_simplify_t.column1) AS Int64) AS __common_expr_1, sum_simplify_t.column2, sum(sum_simplify_t.column1)
272+
04)------Aggregate: groupBy=[[sum_simplify_t.column2]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
273+
05)--------TableScan: sum_simplify_t projection=[column1, column2]
267274
physical_plan
268275
01)SortPreservingMergeExec: [column2@0 DESC NULLS LAST]
269276
02)--SortExec: expr=[column2@0 DESC NULLS LAST], preserve_partitioning=[true]
@@ -284,7 +291,7 @@ EXPLAIN SELECT SUM(1 + column1), SUM(column1 + 1) FROM sum_simplify_t;
284291
----
285292
logical_plan
286293
01)Projection: __common_expr_1 AS sum(Int64(1) + sum_simplify_t.column1), __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(1))
287-
02)--Projection: sum(sum_simplify_t.column1) + count(sum_simplify_t.column1) AS __common_expr_1
294+
02)--Projection: sum(sum_simplify_t.column1) + CAST(count(sum_simplify_t.column1) AS Int64) AS __common_expr_1
288295
03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
289296
04)------TableScan: sum_simplify_t projection=[column1]
290297
physical_plan
@@ -314,15 +321,16 @@ EXPLAIN SELECT arrow_typeof(SUM(val + 1)), SUM(val + 1), SUM(val + 2) FROM tbl;
314321
----
315322
logical_plan
316323
01)Projection: arrow_typeof(sum(tbl.val + Int64(1))), sum(tbl.val + Int64(1)), sum(tbl.val + Int64(2))
317-
02)--Projection: sum(tbl.val) + count(tbl.val) AS sum(tbl.val + Int64(1)), sum(tbl.val) + Int64(2) * count(tbl.val) AS sum(tbl.val + Int64(2))
318-
03)----Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS tbl.val), count(__common_expr_1 AS tbl.val)]]
319-
04)------Projection: CAST(tbl.val AS Int64) AS __common_expr_1
320-
05)--------TableScan: tbl projection=[val]
324+
02)--Projection: sum(tbl.val) + __common_expr_1 AS sum(tbl.val + Int64(1)), sum(tbl.val) + Int64(2) * __common_expr_1 AS sum(tbl.val + Int64(2))
325+
03)----Projection: CAST(count(tbl.val) AS Int64) AS __common_expr_1, sum(tbl.val)
326+
04)------Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_2 AS tbl.val), count(__common_expr_2 AS tbl.val)]]
327+
05)--------Projection: CAST(tbl.val AS Int64) AS __common_expr_2
328+
06)----------TableScan: tbl projection=[val]
321329
physical_plan
322330
01)ProjectionExec: expr=[arrow_typeof(sum(tbl.val + Int64(1))@0) as arrow_typeof(sum(tbl.val + Int64(1))), sum(tbl.val + Int64(1))@0 as sum(tbl.val + Int64(1)), sum(tbl.val + Int64(2))@1 as sum(tbl.val + Int64(2))]
323331
02)--ProjectionExec: expr=[sum(tbl.val)@0 + count(tbl.val)@1 as sum(tbl.val + Int64(1)), sum(tbl.val)@0 + 2 * count(tbl.val)@1 as sum(tbl.val + Int64(2))]
324332
03)----AggregateExec: mode=Single, gby=[], aggr=[sum(tbl.val), count(tbl.val)]
325-
04)------ProjectionExec: expr=[CAST(val@0 AS Int64) as __common_expr_1]
333+
04)------ProjectionExec: expr=[CAST(val@0 AS Int64) as __common_expr_2]
326334
05)--------DataSourceExec: partitions=1, partition_sizes=[2]
327335

328336
# Checks equivalent rewritten form (SUM + COUNT terms) matches transformed SUM semantics.

0 commit comments

Comments
 (0)