Skip to content

Commit 385d9db

Browse files
authored
try to remove redundant alias in expression rewriter and select (#20867)
## Which issue does this PR close? Does not close an issue. ## Rationale for this change In #20780 (comment) @alamb asked whether we can simplify the redundant alias `count(*) AS count(*)` to `count(*)`, and I tried to give this a go. ### I'm not sure about the implications at the moment; it would be great to have input on this PR. ## What changes are included in this PR? Main changes are in: - order_by.rs: match only top-level expressions instead of recursively searching sub-expressions (otherwise we may match the wrong expressions) - select.rs: strip the alias before comparing; otherwise we don't use the existing alias at all ## Are these changes tested? I've added some tests for aliases. Existing tests and plan outputs changed as well, as you can see in the PR. ## Are there any user-facing changes? Plans will change, but I'm not sure whether that has any impact.
1 parent 6b71523 commit 385d9db

5 files changed

Lines changed: 279 additions & 36 deletions

File tree

datafusion/core/tests/dataframe/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3004,7 +3004,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
30043004
+---------------+------------------------------------------------------------------------------------+
30053005
| plan_type | plan |
30063006
+---------------+------------------------------------------------------------------------------------+
3007-
| logical_plan | Sort: count(*) AS count(*) ASC NULLS LAST |
3007+
| logical_plan | Sort: count(*) ASC NULLS LAST |
30083008
| | Projection: t1.b, count(Int64(1)) AS count(*) |
30093009
| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] |
30103010
| | TableScan: t1 projection=[b] |

datafusion/expr/src/expr_rewriter/order_by.rs

Lines changed: 216 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ use crate::expr::Alias;
2121
use crate::expr_rewriter::normalize_col;
2222
use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort};
2323

24-
use datafusion_common::tree_node::{
25-
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
26-
};
24+
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
2725
use datafusion_common::{Column, Result};
2826

2927
/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
@@ -104,29 +102,27 @@ fn rewrite_in_terms_of_projection(
104102

105103
let search_col = Expr::Column(Column::new_unqualified(name));
106104

107-
// look for the column named the same as this expr
108-
let mut found = None;
109-
for proj_expr in proj_exprs {
110-
proj_expr.apply(|e| {
111-
if expr_match(&search_col, e) {
112-
found = Some(e.clone());
113-
return Ok(TreeNodeRecursion::Stop);
114-
}
115-
Ok(TreeNodeRecursion::Continue)
116-
})?;
117-
}
105+
// Search only top-level projection expressions for a match.
106+
// We intentionally avoid a recursive search (e.g. `apply`) to
107+
// prevent matching sub-expressions of composites like
108+
// `min(c2) + max(c3)` when the ORDER BY is just `min(c2)`.
109+
let found = proj_exprs
110+
.iter()
111+
.find(|proj_expr| expr_match(&search_col, proj_expr));
118112

119113
if let Some(found) = found {
114+
let (qualifier, field_name) = found.qualified_name();
115+
let col = Expr::Column(Column::new(qualifier, field_name));
120116
return Ok(Transformed::yes(match normalized_expr {
121117
Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast {
122-
expr: Box::new(found),
118+
expr: Box::new(col),
123119
field,
124120
}),
125121
Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast {
126-
expr: Box::new(found),
122+
expr: Box::new(col),
127123
field,
128124
}),
129-
_ => found,
125+
_ => col,
130126
}));
131127
}
132128

@@ -160,7 +156,10 @@ mod test {
160156

161157
use super::*;
162158
use crate::test::function_stub::avg;
159+
use crate::test::function_stub::count;
160+
use crate::test::function_stub::max;
163161
use crate::test::function_stub::min;
162+
use crate::test::function_stub::sum;
164163

165164
#[test]
166165
fn rewrite_sort_cols_by_agg() {
@@ -242,17 +241,14 @@ mod test {
242241
TestCase {
243242
desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
244243
input: sort(col("c1") + min(col("c2"))),
245-
// should be "c1" not t.c1
246244
expected: sort(
247245
col("c1") + Expr::Column(Column::new_unqualified("min(t.c2)")),
248246
),
249247
},
250248
TestCase {
251-
desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
249+
desc: r#"avg(c3) --> "average" (column *named* "average", from alias)"#,
252250
input: sort(avg(col("c3"))),
253-
expected: sort(
254-
Expr::Column(Column::new_unqualified("avg(t.c3)")).alias("average"),
255-
),
251+
expected: sort(col("average")),
256252
},
257253
];
258254

@@ -261,6 +257,202 @@ mod test {
261257
}
262258
}
263259

260+
/// When an aggregate is aliased in the projection,
261+
/// ORDER BY on the original aggregate expression should resolve to
262+
/// a Column reference using the alias name — not leak the inner
263+
/// Alias expression node or resolve to a descendant subtree.
264+
#[test]
265+
fn rewrite_sort_resolves_alias_to_column_ref() {
266+
let plan = make_input()
267+
.aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
268+
.unwrap()
269+
.project(vec![
270+
col("c1"),
271+
min(col("c2")).alias("min_val"),
272+
max(col("c3")).alias("max_val"),
273+
])
274+
.unwrap()
275+
.build()
276+
.unwrap();
277+
278+
let cases = vec![
279+
TestCase {
280+
desc: "min(c2) with alias 'min_val' should resolve to col(min_val)",
281+
input: sort(min(col("c2"))),
282+
expected: sort(col("min_val")),
283+
},
284+
TestCase {
285+
desc: "max(c3) with alias 'max_val' should resolve to col(max_val)",
286+
input: sort(max(col("c3"))),
287+
expected: sort(col("max_val")),
288+
},
289+
];
290+
291+
for case in cases {
292+
case.run(&plan)
293+
}
294+
}
295+
296+
#[test]
297+
fn composite_proj_expr_containing_sort_col_as_subexpr() {
298+
let plan = make_input()
299+
.aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
300+
.unwrap()
301+
.project(vec![
302+
col("c1"),
303+
(min(col("c2")) + max(col("c3"))).alias("range"),
304+
min(col("c2")).alias("min_val"),
305+
max(col("c3")).alias("max_val"),
306+
])
307+
.unwrap()
308+
.build()
309+
.unwrap();
310+
311+
let cases = vec![
312+
TestCase {
313+
desc: "sort by min(c2) should resolve to col(min_val), not col(range)",
314+
input: sort(min(col("c2"))),
315+
expected: sort(col("min_val")),
316+
},
317+
TestCase {
318+
desc: "sort by max(c3) should resolve to col(max_val), not col(range)",
319+
input: sort(max(col("c3"))),
320+
expected: sort(col("max_val")),
321+
},
322+
];
323+
324+
for case in cases {
325+
case.run(&plan)
326+
}
327+
}
328+
329+
#[test]
330+
fn composite_before_standalone_should_not_shadow() {
331+
let plan = make_input()
332+
.aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c2"))])
333+
.unwrap()
334+
.project(vec![
335+
col("c1"),
336+
(min(col("c2")) + max(col("c2"))).alias("combined"),
337+
min(col("c2")),
338+
])
339+
.unwrap()
340+
.build()
341+
.unwrap();
342+
343+
let cases = vec![TestCase {
344+
desc: "sort by min(c2) should resolve to col(min(t.c2)), not col(combined)",
345+
input: sort(min(col("c2"))),
346+
expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
347+
}];
348+
349+
for case in cases {
350+
case.run(&plan)
351+
}
352+
}
353+
354+
#[test]
355+
fn duplicate_aggregate_in_multiple_proj_exprs() {
356+
let plan = make_input()
357+
.aggregate(vec![col("c1")], vec![min(col("c2"))])
358+
.unwrap()
359+
.project(vec![
360+
col("c1"),
361+
min(col("c2")).alias("first_alias"),
362+
min(col("c2")).alias("second_alias"),
363+
])
364+
.unwrap()
365+
.build()
366+
.unwrap();
367+
368+
let cases = vec![TestCase {
369+
desc: "sort by min(c2) with two aliases picks first_alias",
370+
input: sort(min(col("c2"))),
371+
expected: sort(col("first_alias")),
372+
}];
373+
374+
for case in cases {
375+
case.run(&plan)
376+
}
377+
}
378+
379+
#[test]
380+
fn sort_agg_not_in_select_with_aliased_aggs() {
381+
let plan = make_input()
382+
.aggregate(
383+
vec![col("c1")],
384+
vec![min(col("c2")), max(col("c3")), sum(col("c3"))],
385+
)
386+
.unwrap()
387+
.project(vec![
388+
col("c1"),
389+
min(col("c2")).alias("min_val"),
390+
max(col("c3")).alias("max_val"),
391+
])
392+
.unwrap()
393+
.build()
394+
.unwrap();
395+
396+
let cases = vec![TestCase {
397+
desc: "sort by sum(c3) not in projection should not be rewritten",
398+
input: sort(sum(col("c3"))),
399+
expected: sort(sum(col("c3"))),
400+
}];
401+
402+
for case in cases {
403+
case.run(&plan)
404+
}
405+
}
406+
407+
#[test]
408+
fn cast_on_aliased_aggregate() {
409+
let plan = make_input()
410+
.aggregate(vec![col("c1")], vec![min(col("c2"))])
411+
.unwrap()
412+
.project(vec![col("c1"), min(col("c2")).alias("min_val")])
413+
.unwrap()
414+
.build()
415+
.unwrap();
416+
417+
let cases = vec![
418+
TestCase {
419+
desc: "CAST on aliased aggregate should preserve cast and resolve alias",
420+
input: sort(cast(min(col("c2")), DataType::Int64)),
421+
expected: sort(cast(col("min_val"), DataType::Int64)),
422+
},
423+
TestCase {
424+
desc: "TryCast on aliased aggregate should preserve try_cast and resolve alias",
425+
input: sort(try_cast(min(col("c2")), DataType::Int64)),
426+
expected: sort(try_cast(col("min_val"), DataType::Int64)),
427+
},
428+
];
429+
430+
for case in cases {
431+
case.run(&plan)
432+
}
433+
}
434+
435+
#[test]
436+
fn count_star_with_alias() {
437+
let plan = make_input()
438+
.aggregate(vec![col("c1")], vec![count(lit(1))])
439+
.unwrap()
440+
.project(vec![col("c1"), count(lit(1)).alias("cnt")])
441+
.unwrap()
442+
.build()
443+
.unwrap();
444+
445+
let cases = vec![TestCase {
446+
desc: "sort by count(1) should resolve to cnt alias",
447+
input: sort(count(lit(1))),
448+
expected: sort(col("cnt")),
449+
}];
450+
451+
for case in cases {
452+
case.run(&plan)
453+
}
454+
}
455+
264456
#[test]
265457
fn preserve_cast() {
266458
let plan = make_input()
@@ -275,12 +467,12 @@ mod test {
275467
TestCase {
276468
desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
277469
input: sort(cast(col("c2"), DataType::Int64)),
278-
expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
470+
expected: sort(cast(col("c2"), DataType::Int64)),
279471
},
280472
TestCase {
281473
desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
282474
input: sort(try_cast(col("c2"), DataType::Int64)),
283-
expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)),
475+
expected: sort(try_cast(col("c2"), DataType::Int64)),
284476
},
285477
];
286478

datafusion/sql/src/select.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,13 +1056,16 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
10561056
.iter()
10571057
.find_map(|select_expr| {
10581058
// Only consider aliased expressions
1059-
if let Expr::Alias(alias) = select_expr
1060-
&& alias.expr.as_ref() == &rewritten_expr
1061-
{
1062-
// Use the alias name
1063-
return Some(Expr::Column(Column::new_unqualified(
1064-
alias.name.clone(),
1065-
)));
1059+
if let Expr::Alias(alias) = select_expr {
1060+
let rewritten_unaliased = match &rewritten_expr {
1061+
Expr::Alias(a) => a.expr.as_ref(),
1062+
other => other,
1063+
};
1064+
if alias.expr.as_ref() == rewritten_unaliased {
1065+
return Some(Expr::Column(Column::new_unqualified(
1066+
alias.name.clone(),
1067+
)));
1068+
}
10661069
}
10671070
None
10681071
})

datafusion/sqllogictest/test_files/clickbench.slt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ query TT
205205
EXPLAIN SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC;
206206
----
207207
logical_plan
208-
01)Sort: count(*) AS count(*) DESC NULLS FIRST
208+
01)Sort: count(*) DESC NULLS FIRST
209209
02)--Projection: hits.AdvEngineID, count(Int64(1)) AS count(*)
210210
03)----Aggregate: groupBy=[[hits.AdvEngineID]], aggr=[[count(Int64(1))]]
211211
04)------SubqueryAlias: hits
@@ -431,7 +431,7 @@ query TT
431431
EXPLAIN SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10;
432432
----
433433
logical_plan
434-
01)Sort: count(*) AS count(*) DESC NULLS FIRST, fetch=10
434+
01)Sort: count(*) DESC NULLS FIRST, fetch=10
435435
02)--Projection: hits.UserID, count(Int64(1)) AS count(*)
436436
03)----Aggregate: groupBy=[[hits.UserID]], aggr=[[count(Int64(1))]]
437437
04)------SubqueryAlias: hits
@@ -459,7 +459,7 @@ query TT
459459
EXPLAIN SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10;
460460
----
461461
logical_plan
462-
01)Sort: count(*) AS count(*) DESC NULLS FIRST, fetch=10
462+
01)Sort: count(*) DESC NULLS FIRST, fetch=10
463463
02)--Projection: hits.UserID, hits.SearchPhrase, count(Int64(1)) AS count(*)
464464
03)----Aggregate: groupBy=[[hits.UserID, hits.SearchPhrase]], aggr=[[count(Int64(1))]]
465465
04)------SubqueryAlias: hits
@@ -514,7 +514,7 @@ query TT
514514
EXPLAIN SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10;
515515
----
516516
logical_plan
517-
01)Sort: count(*) AS count(*) DESC NULLS FIRST, fetch=10
517+
01)Sort: count(*) DESC NULLS FIRST, fetch=10
518518
02)--Projection: hits.UserID, date_part(Utf8("MINUTE"),to_timestamp_seconds(hits.EventTime)) AS m, hits.SearchPhrase, count(Int64(1)) AS count(*)
519519
03)----Aggregate: groupBy=[[hits.UserID, date_part(Utf8("MINUTE"), to_timestamp_seconds(hits.EventTime)), hits.SearchPhrase]], aggr=[[count(Int64(1))]]
520520
04)------SubqueryAlias: hits

0 commit comments

Comments
 (0)