Skip to content

Commit 848cd63

Browse files
authored
Eliminate deterministic group by keys with deterministic transformations (#20706)
## Which issue does this PR close?

- Helps with #18489

## Rationale for this change

Speed up queries like the following one (picked arbitrarily), where the extra `GROUP BY` keys are deterministic functions of a column that is already a grouping key:

```
SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3
```

## What changes are included in this PR?

## Are these changes tested?

## Are there any user-facing changes?
1 parent 1f0232c commit 848cd63

2 files changed

Lines changed: 99 additions & 29 deletions

File tree

datafusion/optimizer/src/eliminate_group_by_constant.rs

Lines changed: 90 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause
18+
//! [`EliminateGroupByConstant`] removes constant and functionally redundant
19+
//! expressions from `GROUP BY` clause
1920
use crate::optimizer::ApplyOrder;
2021
use crate::{OptimizerConfig, OptimizerRule};
2122

23+
use std::collections::HashSet;
24+
2225
use datafusion_common::Result;
2326
use datafusion_common::tree_node::Transformed;
2427
use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility};
@@ -47,25 +50,30 @@ impl OptimizerRule for EliminateGroupByConstant {
4750
) -> Result<Transformed<LogicalPlan>> {
4851
match plan {
4952
LogicalPlan::Aggregate(aggregate) => {
50-
let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate
53+
// Collect bare column references in GROUP BY
54+
let group_by_columns: HashSet<&datafusion_common::Column> = aggregate
5155
.group_expr
5256
.iter()
53-
.partition(|expr| is_constant_expression(expr));
54-
55-
// If no constant expressions found (nothing to optimize) or
56-
// constant expression is the only expression in aggregate,
57-
// optimization is skipped
58-
if const_group_expr.is_empty()
59-
|| (!const_group_expr.is_empty()
60-
&& nonconst_group_expr.is_empty()
61-
&& aggregate.aggr_expr.is_empty())
57+
.filter_map(|expr| match expr {
58+
Expr::Column(c) => Some(c),
59+
_ => None,
60+
})
61+
.collect();
62+
63+
let (redundant, required): (Vec<_>, Vec<_>) = aggregate
64+
.group_expr
65+
.iter()
66+
.partition(|expr| is_redundant_group_expr(expr, &group_by_columns));
67+
68+
if redundant.is_empty()
69+
|| (required.is_empty() && aggregate.aggr_expr.is_empty())
6270
{
6371
return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
6472
}
6573

6674
let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
6775
aggregate.input,
68-
nonconst_group_expr.into_iter().cloned().collect(),
76+
required.into_iter().cloned().collect(),
6977
aggregate.aggr_expr.clone(),
7078
)?);
7179

@@ -91,23 +99,47 @@ impl OptimizerRule for EliminateGroupByConstant {
9199
}
92100
}
93101

94-
/// Checks if expression is constant, and can be eliminated from group by.
95-
///
96-
/// Intended to be used only within this rule, helper function, which heavily
97-
/// relies on `SimplifyExpressions` result.
98-
fn is_constant_expression(expr: &Expr) -> bool {
102+
/// Checks if a GROUP BY expression is redundant (can be removed without
103+
/// changing grouping semantics). An expression is redundant if it is a
104+
/// deterministic function of constants and columns already present as bare
105+
/// column references in the GROUP BY.
106+
fn is_redundant_group_expr(
107+
expr: &Expr,
108+
group_by_columns: &HashSet<&datafusion_common::Column>,
109+
) -> bool {
110+
// Bare column references are never redundant - they define the grouping
111+
if matches!(expr, Expr::Column(_)) {
112+
return false;
113+
}
114+
is_deterministic_of(expr, group_by_columns)
115+
}
116+
117+
/// Returns true if `expr` is a deterministic expression whose only column
118+
/// references are contained in `known_columns`.
119+
fn is_deterministic_of(
120+
expr: &Expr,
121+
known_columns: &HashSet<&datafusion_common::Column>,
122+
) -> bool {
99123
match expr {
100-
Expr::Alias(e) => is_constant_expression(&e.expr),
124+
Expr::Alias(e) => is_deterministic_of(&e.expr, known_columns),
125+
Expr::Column(c) => known_columns.contains(c),
126+
Expr::Literal(_, _) => true,
101127
Expr::BinaryExpr(e) => {
102-
is_constant_expression(&e.left) && is_constant_expression(&e.right)
128+
is_deterministic_of(&e.left, known_columns)
129+
&& is_deterministic_of(&e.right, known_columns)
103130
}
104-
Expr::Literal(_, _) => true,
105131
Expr::ScalarFunction(e) => {
106132
matches!(
107133
e.func.signature().volatility,
108134
Volatility::Immutable | Volatility::Stable
109-
) && e.args.iter().all(is_constant_expression)
135+
) && e
136+
.args
137+
.iter()
138+
.all(|arg| is_deterministic_of(arg, known_columns))
110139
}
140+
Expr::Cast(e) => is_deterministic_of(&e.expr, known_columns),
141+
Expr::TryCast(e) => is_deterministic_of(&e.expr, known_columns),
142+
Expr::Negative(e) => is_deterministic_of(e, known_columns),
111143
_ => false,
112144
}
113145
}
@@ -268,6 +300,43 @@ mod tests {
268300
")
269301
}
270302

303+
#[test]
304+
fn test_eliminate_deterministic_expr_of_group_by_column() -> Result<()> {
305+
let scan = test_table_scan()?;
306+
// GROUP BY a, a - 1, a - 2, a - 3 -> GROUP BY a
307+
let plan = LogicalPlanBuilder::from(scan)
308+
.aggregate(
309+
vec![
310+
col("a"),
311+
col("a") - lit(1u32),
312+
col("a") - lit(2u32),
313+
col("a") - lit(3u32),
314+
],
315+
vec![count(col("c"))],
316+
)?
317+
.build()?;
318+
319+
assert_optimized_plan_equal!(plan, @r"
320+
Projection: test.a, test.a - UInt32(1), test.a - UInt32(2), test.a - UInt32(3), count(test.c)
321+
Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
322+
TableScan: test
323+
")
324+
}
325+
326+
#[test]
327+
fn test_no_eliminate_independent_columns() -> Result<()> {
328+
// GROUP BY a, b - 1 should NOT eliminate b - 1 (b is not a group by column)
329+
let scan = test_table_scan()?;
330+
let plan = LogicalPlanBuilder::from(scan)
331+
.aggregate(vec![col("a"), col("b") - lit(1u32)], vec![count(col("c"))])?
332+
.build()?;
333+
334+
assert_optimized_plan_equal!(plan, @r"
335+
Aggregate: groupBy=[[test.a, test.b - UInt32(1)]], aggr=[[count(test.c)]]
336+
TableScan: test
337+
")
338+
}
339+
271340
#[test]
272341
fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> {
273342
let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility(

datafusion/sqllogictest/test_files/clickbench.slt

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -959,19 +959,20 @@ EXPLAIN SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT
959959
----
960960
logical_plan
961961
01)Sort: c DESC NULLS FIRST, fetch=10
962-
02)--Projection: hits.ClientIP, hits.ClientIP - Int64(1), hits.ClientIP - Int64(2), hits.ClientIP - Int64(3), count(Int64(1)) AS count(*) AS c
963-
03)----Aggregate: groupBy=[[hits.ClientIP, __common_expr_1 AS hits.ClientIP - Int64(1), __common_expr_1 AS hits.ClientIP - Int64(2), __common_expr_1 AS hits.ClientIP - Int64(3)]], aggr=[[count(Int64(1))]]
964-
04)------Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1, hits.ClientIP
962+
02)--Projection: hits.ClientIP, __common_expr_1 - Int64(1) AS hits.ClientIP - Int64(1), __common_expr_1 - Int64(2) AS hits.ClientIP - Int64(2), __common_expr_1 - Int64(3) AS hits.ClientIP - Int64(3), count(Int64(1)) AS c
963+
03)----Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1, hits.ClientIP, count(Int64(1))
964+
04)------Aggregate: groupBy=[[hits.ClientIP]], aggr=[[count(Int64(1))]]
965965
05)--------SubqueryAlias: hits
966966
06)----------TableScan: hits_raw projection=[ClientIP]
967967
physical_plan
968968
01)SortPreservingMergeExec: [c@4 DESC], fetch=10
969969
02)--SortExec: TopK(fetch=10), expr=[c@4 DESC], preserve_partitioning=[true]
970-
03)----ProjectionExec: expr=[ClientIP@0 as ClientIP, hits.ClientIP - Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP - Int64(2)@2 as hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as hits.ClientIP - Int64(3), count(Int64(1))@4 as c]
971-
04)------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP, hits.ClientIP - Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP - Int64(2)@2 as hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as hits.ClientIP - Int64(3)], aggr=[count(Int64(1))]
972-
05)--------RepartitionExec: partitioning=Hash([ClientIP@0, hits.ClientIP - Int64(1)@1, hits.ClientIP - Int64(2)@2, hits.ClientIP - Int64(3)@3], 4), input_partitions=1
973-
06)----------AggregateExec: mode=Partial, gby=[ClientIP@1 as ClientIP, __common_expr_1@0 - 1 as hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as hits.ClientIP - Int64(2), __common_expr_1@0 - 3 as hits.ClientIP - Int64(3)], aggr=[count(Int64(1))]
974-
07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[CAST(ClientIP@7 AS Int64) as __common_expr_1, ClientIP], file_type=parquet
970+
03)----ProjectionExec: expr=[ClientIP@1 as ClientIP, __common_expr_1@0 - 1 as hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as hits.ClientIP - Int64(2), __common_expr_1@0 - 3 as hits.ClientIP - Int64(3), count(Int64(1))@2 as c]
971+
04)------ProjectionExec: expr=[CAST(ClientIP@0 AS Int64) as __common_expr_1, ClientIP@0 as ClientIP, count(Int64(1))@1 as count(Int64(1))]
972+
05)--------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP], aggr=[count(Int64(1))]
973+
06)----------RepartitionExec: partitioning=Hash([ClientIP@0], 4), input_partitions=1
974+
07)------------AggregateExec: mode=Partial, gby=[ClientIP@0 as ClientIP], aggr=[count(Int64(1))]
975+
08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[ClientIP], file_type=parquet
975976

976977
query IIIII rowsort
977978
SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10;

0 commit comments

Comments
 (0)