@@ -21,9 +21,7 @@ use crate::expr::Alias;
2121use crate :: expr_rewriter:: normalize_col;
2222use crate :: { Cast , Expr , LogicalPlan , TryCast , expr:: Sort } ;
2323
24- use datafusion_common:: tree_node:: {
25- Transformed , TransformedResult , TreeNode , TreeNodeRecursion ,
26- } ;
24+ use datafusion_common:: tree_node:: { Transformed , TransformedResult , TreeNode } ;
2725use datafusion_common:: { Column , Result } ;
2826
2927/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
@@ -104,29 +102,27 @@ fn rewrite_in_terms_of_projection(
104102
105103 let search_col = Expr :: Column ( Column :: new_unqualified ( name) ) ;
106104
107- // look for the column named the same as this expr
108- let mut found = None ;
109- for proj_expr in proj_exprs {
110- proj_expr. apply ( |e| {
111- if expr_match ( & search_col, e) {
112- found = Some ( e. clone ( ) ) ;
113- return Ok ( TreeNodeRecursion :: Stop ) ;
114- }
115- Ok ( TreeNodeRecursion :: Continue )
116- } ) ?;
117- }
105+ // Search only top-level projection expressions for a match.
106+ // We intentionally avoid a recursive search (e.g. `apply`) to
107+ // prevent matching sub-expressions of composites like
108+ // `min(c2) + max(c3)` when the ORDER BY is just `min(c2)`.
109+ let found = proj_exprs
110+ . iter ( )
111+ . find ( |proj_expr| expr_match ( & search_col, proj_expr) ) ;
118112
119113 if let Some ( found) = found {
114+ let ( qualifier, field_name) = found. qualified_name ( ) ;
115+ let col = Expr :: Column ( Column :: new ( qualifier, field_name) ) ;
120116 return Ok ( Transformed :: yes ( match normalized_expr {
121117 Expr :: Cast ( Cast { expr : _, field } ) => Expr :: Cast ( Cast {
122- expr : Box :: new ( found ) ,
118+ expr : Box :: new ( col ) ,
123119 field,
124120 } ) ,
125121 Expr :: TryCast ( TryCast { expr : _, field } ) => Expr :: TryCast ( TryCast {
126- expr : Box :: new ( found ) ,
122+ expr : Box :: new ( col ) ,
127123 field,
128124 } ) ,
129- _ => found ,
125+ _ => col ,
130126 } ) ) ;
131127 }
132128
@@ -160,7 +156,10 @@ mod test {
160156
161157 use super :: * ;
162158 use crate :: test:: function_stub:: avg;
159+ use crate :: test:: function_stub:: count;
160+ use crate :: test:: function_stub:: max;
163161 use crate :: test:: function_stub:: min;
162+ use crate :: test:: function_stub:: sum;
164163
165164 #[ test]
166165 fn rewrite_sort_cols_by_agg ( ) {
@@ -242,17 +241,14 @@ mod test {
242241 TestCase {
243242 desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"# ,
244243 input: sort( col( "c1" ) + min( col( "c2" ) ) ) ,
245- // should be "c1" not t.c1
246244 expected: sort(
247245 col( "c1" ) + Expr :: Column ( Column :: new_unqualified( "min(t.c2)" ) ) ,
248246 ) ,
249247 } ,
250248 TestCase {
251- desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3) ", aliased )"# ,
249+ desc: r#"avg(c3) --> "average" (column *named* "average ", from alias )"# ,
252250 input: sort( avg( col( "c3" ) ) ) ,
253- expected: sort(
254- Expr :: Column ( Column :: new_unqualified( "avg(t.c3)" ) ) . alias( "average" ) ,
255- ) ,
251+ expected: sort( col( "average" ) ) ,
256252 } ,
257253 ] ;
258254
@@ -261,6 +257,202 @@ mod test {
261257 }
262258 }
263259
/// Checks that sorting by an aggregate which the projection exposes under
/// an alias is rewritten to a plain column reference named after the alias.
///
/// The plan groups by `c1`, computes `min(c2)` and `max(c3)`, and projects
/// them as `min_val` and `max_val`. Each sort expression is expected to come
/// back as `col(<alias>)` instead of carrying the aliased aggregate itself.
#[test]
fn rewrite_sort_resolves_alias_to_column_ref() {
    let aggregated = make_input()
        .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
        .unwrap();
    let plan = aggregated
        .project(vec![
            col("c1"),
            min(col("c2")).alias("min_val"),
            max(col("c3")).alias("max_val"),
        ])
        .unwrap()
        .build()
        .unwrap();

    let by_min = TestCase {
        desc: "min(c2) with alias 'min_val' should resolve to col(min_val)",
        input: sort(min(col("c2"))),
        expected: sort(col("min_val")),
    };
    let by_max = TestCase {
        desc: "max(c3) with alias 'max_val' should resolve to col(max_val)",
        input: sort(max(col("c3"))),
        expected: sort(col("max_val")),
    };

    // Both cases are built up front and then run in declaration order.
    vec![by_min, by_max]
        .into_iter()
        .for_each(|case| case.run(&plan));
}
295+
/// Checks that a composite projection expression which merely *contains*
/// the sorted aggregate is not picked as the rewrite target.
///
/// The projection lists `min(c2) + max(c3)` (aliased `range`) ahead of the
/// standalone `min(c2)` (`min_val`) and `max(c3)` (`max_val`). Sorting by
/// either aggregate is expected to resolve to its own alias, not `range`.
#[test]
fn composite_proj_expr_containing_sort_col_as_subexpr() {
    let aggregated = make_input()
        .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
        .unwrap();
    let plan = aggregated
        .project(vec![
            col("c1"),
            (min(col("c2")) + max(col("c3"))).alias("range"),
            min(col("c2")).alias("min_val"),
            max(col("c3")).alias("max_val"),
        ])
        .unwrap()
        .build()
        .unwrap();

    let by_min = TestCase {
        desc: "sort by min(c2) should resolve to col(min_val), not col(range)",
        input: sort(min(col("c2"))),
        expected: sort(col("min_val")),
    };
    let by_max = TestCase {
        desc: "sort by max(c3) should resolve to col(max_val), not col(range)",
        input: sort(max(col("c3"))),
        expected: sort(col("max_val")),
    };

    // Both cases are built up front and then run in declaration order.
    vec![by_min, by_max]
        .into_iter()
        .for_each(|case| case.run(&plan));
}
328+
/// Checks that a composite expression listed *before* the standalone
/// aggregate in the projection does not shadow it.
///
/// `min(c2) + max(c2)` (aliased `combined`) precedes the un-aliased
/// `min(c2)`. Sorting by `min(c2)` is expected to become the unqualified
/// column named `min(t.c2)`.
#[test]
fn composite_before_standalone_should_not_shadow() {
    let aggregated = make_input()
        .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c2"))])
        .unwrap();
    let plan = aggregated
        .project(vec![
            col("c1"),
            (min(col("c2")) + max(col("c2"))).alias("combined"),
            min(col("c2")),
        ])
        .unwrap()
        .build()
        .unwrap();

    let case = TestCase {
        desc: "sort by min(c2) should resolve to col(min(t.c2)), not col(combined)",
        input: sort(min(col("c2"))),
        expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
    };
    case.run(&plan);
}
353+
/// Checks which alias wins when the same aggregate appears twice in the
/// projection under different aliases.
///
/// `min(c2)` is projected as both `first_alias` and `second_alias`.
/// Sorting by `min(c2)` is expected to resolve to `first_alias`.
#[test]
fn duplicate_aggregate_in_multiple_proj_exprs() {
    let aggregated = make_input()
        .aggregate(vec![col("c1")], vec![min(col("c2"))])
        .unwrap();
    let plan = aggregated
        .project(vec![
            col("c1"),
            min(col("c2")).alias("first_alias"),
            min(col("c2")).alias("second_alias"),
        ])
        .unwrap()
        .build()
        .unwrap();

    let case = TestCase {
        desc: "sort by min(c2) with two aliases picks first_alias",
        input: sort(min(col("c2"))),
        expected: sort(col("first_alias")),
    };
    case.run(&plan);
}
378+
/// Checks that a sort on an aggregate absent from the projection is left
/// untouched even when other aggregates are projected under aliases.
///
/// The aggregate node computes `min(c2)`, `max(c3)` and `sum(c3)`, but only
/// the first two are projected (as `min_val` / `max_val`). Sorting by
/// `sum(c3)` is expected to come back unchanged.
#[test]
fn sort_agg_not_in_select_with_aliased_aggs() {
    let aggregated = make_input()
        .aggregate(
            vec![col("c1")],
            vec![min(col("c2")), max(col("c3")), sum(col("c3"))],
        )
        .unwrap();
    let plan = aggregated
        .project(vec![
            col("c1"),
            min(col("c2")).alias("min_val"),
            max(col("c3")).alias("max_val"),
        ])
        .unwrap()
        .build()
        .unwrap();

    let case = TestCase {
        desc: "sort by sum(c3) not in projection should not be rewritten",
        input: sort(sum(col("c3"))),
        expected: sort(sum(col("c3"))),
    };
    case.run(&plan);
}
406+
/// Checks that a `CAST` / `TRY_CAST` wrapped around an aliased aggregate
/// keeps its cast wrapper while the inner aggregate is replaced by a
/// reference to the alias column.
///
/// `min(c2)` is projected as `min_val`. Sorting by the cast of `min(c2)` to
/// `Int64` is expected to become the same cast applied to `col("min_val")`.
#[test]
fn cast_on_aliased_aggregate() {
    let aggregated = make_input()
        .aggregate(vec![col("c1")], vec![min(col("c2"))])
        .unwrap();
    let plan = aggregated
        .project(vec![col("c1"), min(col("c2")).alias("min_val")])
        .unwrap()
        .build()
        .unwrap();

    let with_cast = TestCase {
        desc: "CAST on aliased aggregate should preserve cast and resolve alias",
        input: sort(cast(min(col("c2")), DataType::Int64)),
        expected: sort(cast(col("min_val"), DataType::Int64)),
    };
    let with_try_cast = TestCase {
        desc: "TryCast on aliased aggregate should preserve try_cast and resolve alias",
        input: sort(try_cast(min(col("c2")), DataType::Int64)),
        expected: sort(try_cast(col("min_val"), DataType::Int64)),
    };

    // Both cases are built up front and then run in declaration order.
    vec![with_cast, with_try_cast]
        .into_iter()
        .for_each(|case| case.run(&plan));
}
434+
/// Checks alias resolution for a `count` over a literal argument.
///
/// `count(1)` is projected as `cnt`. Sorting by `count(1)` is expected to
/// be rewritten to `col("cnt")`.
#[test]
fn count_star_with_alias() {
    let aggregated = make_input()
        .aggregate(vec![col("c1")], vec![count(lit(1))])
        .unwrap();
    let plan = aggregated
        .project(vec![col("c1"), count(lit(1)).alias("cnt")])
        .unwrap()
        .build()
        .unwrap();

    let case = TestCase {
        desc: "sort by count(1) should resolve to cnt alias",
        input: sort(count(lit(1))),
        expected: sort(col("cnt")),
    };
    case.run(&plan);
}
455+
264456 #[ test]
265457 fn preserve_cast ( ) {
266458 let plan = make_input ( )
@@ -275,12 +467,12 @@ mod test {
275467 TestCase {
276468 desc: "Cast is preserved by rewrite_sort_cols_by_aggs" ,
277469 input: sort( cast( col( "c2" ) , DataType :: Int64 ) ) ,
278- expected: sort( cast( col( "c2" ) . alias ( "c2" ) , DataType :: Int64 ) ) ,
470+ expected: sort( cast( col( "c2" ) , DataType :: Int64 ) ) ,
279471 } ,
280472 TestCase {
281473 desc: "TryCast is preserved by rewrite_sort_cols_by_aggs" ,
282474 input: sort( try_cast( col( "c2" ) , DataType :: Int64 ) ) ,
283- expected: sort( try_cast( col( "c2" ) . alias ( "c2" ) , DataType :: Int64 ) ) ,
475+ expected: sort( try_cast( col( "c2" ) , DataType :: Int64 ) ) ,
284476 } ,
285477 ] ;
286478
0 commit comments