Skip to content

Commit da4306c

Browse files
fix: fold Limit/Sort into outer SELECT when Projection claims Aggregate through them
When a Projection's `reconstruct_select_statement` reaches through a Limit or Sort to claim an Aggregate, the Limit/Sort arm would later see `already_projected` and wrap everything in a spurious derived subquery, emitting the aggregate twice.

Fix: in the Projection arm, after claiming the Aggregate, check whether the Projection's direct child is a Limit or Sort. If so, fold its clauses (LIMIT/OFFSET for a Limit; ORDER BY, plus LIMIT when the Sort carries a fetch, for a Sort) into the current query and recurse into the Limit/Sort's child, skipping the node entirely.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 66629b7 commit da4306c

2 files changed

Lines changed: 120 additions & 81 deletions

File tree

datafusion/sql/src/unparser/plan.rs

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,14 @@ impl Unparser<'_> {
226226
/// Reconstructs a SELECT SQL statement from a logical plan by unprojecting column expressions
227227
/// found in a [Projection] node. This requires scanning the plan tree for relevant Aggregate
228228
/// and Window nodes and matching column expressions to the appropriate agg or window expressions.
229+
///
230+
/// Returns `true` if an Aggregate node was found and claimed for this SELECT.
229231
fn reconstruct_select_statement(
230232
&self,
231233
plan: &LogicalPlan,
232234
p: &Projection,
233235
select: &mut SelectBuilder,
234-
) -> Result<()> {
236+
) -> Result<bool> {
235237
let mut exprs = p.expr.clone();
236238

237239
// If an Unnest node is found within the select, find and unproject the unnest column
@@ -264,6 +266,7 @@ impl Unparser<'_> {
264266
.collect::<Result<Vec<_>>>()?,
265267
vec![],
266268
));
269+
Ok(true)
267270
}
268271
(None, Some(window)) => {
269272
let items = exprs
@@ -275,16 +278,17 @@ impl Unparser<'_> {
275278
.collect::<Result<Vec<_>>>()?;
276279

277280
select.projection(items);
281+
Ok(false)
278282
}
279283
_ => {
280284
let items = exprs
281285
.iter()
282286
.map(|e| self.select_item_to_sql(e))
283287
.collect::<Result<Vec<_>>>()?;
284288
select.projection(items);
289+
Ok(false)
285290
}
286291
}
287-
Ok(())
288292
}
289293

290294
fn derive(
@@ -423,7 +427,76 @@ impl Unparser<'_> {
423427
columns,
424428
);
425429
}
426-
self.reconstruct_select_statement(plan, p, select)?;
430+
let found_agg = self.reconstruct_select_statement(plan, p, select)?;
431+
432+
// If the Projection claimed an Aggregate by reaching through
433+
// a Limit or Sort, fold those clauses into the current query
434+
// and skip the node during recursion. Otherwise the Limit/Sort
435+
// arm would see `already_projected` and wrap everything in a
436+
// spurious derived subquery.
437+
if found_agg {
438+
if let LogicalPlan::Limit(limit) = p.input.as_ref() {
439+
if let Some(fetch) = &limit.fetch {
440+
let Some(query) = query.as_mut() else {
441+
return internal_err!(
442+
"Limit operator only valid in a statement context."
443+
);
444+
};
445+
query.limit(Some(self.expr_to_sql(fetch)?));
446+
}
447+
if let Some(skip) = &limit.skip {
448+
let Some(query) = query.as_mut() else {
449+
return internal_err!(
450+
"Offset operator only valid in a statement context."
451+
);
452+
};
453+
query.offset(Some(ast::Offset {
454+
rows: ast::OffsetRows::None,
455+
value: self.expr_to_sql(skip)?,
456+
}));
457+
}
458+
return self.select_to_sql_recursively(
459+
limit.input.as_ref(),
460+
query,
461+
select,
462+
relation,
463+
);
464+
}
465+
if let LogicalPlan::Sort(sort) = p.input.as_ref() {
466+
let Some(query_ref) = query.as_mut() else {
467+
return internal_err!(
468+
"Sort operator only valid in a statement context."
469+
);
470+
};
471+
if let Some(fetch) = sort.fetch {
472+
query_ref.limit(Some(ast::Expr::value(ast::Value::Number(
473+
fetch.to_string(),
474+
false,
475+
))));
476+
}
477+
let agg =
478+
find_agg_node_within_select(plan, select.already_projected());
479+
let sort_exprs: Vec<SortExpr> = sort
480+
.expr
481+
.iter()
482+
.map(|sort_expr| {
483+
unproject_sort_expr(
484+
sort_expr.clone(),
485+
agg,
486+
sort.input.as_ref(),
487+
)
488+
})
489+
.collect::<Result<Vec<_>>>()?;
490+
query_ref.order_by(self.sorts_to_sql(&sort_exprs)?);
491+
return self.select_to_sql_recursively(
492+
sort.input.as_ref(),
493+
query,
494+
select,
495+
relation,
496+
);
497+
}
498+
}
499+
427500
self.select_to_sql_recursively(p.input.as_ref(), query, select, relation)
428501
}
429502
LogicalPlan::Filter(filter) => {

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 44 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2919,52 +2919,21 @@ fn roundtrip_subquery_aggregate_with_column_alias() -> Result<(), DataFusionErro
29192919
Ok(())
29202920
}
29212921

2922-
/// Roundtrip: aggregate over a subquery projection.
2922+
/// Roundtrip: aggregate over a subquery projection with limit.
29232923
#[test]
29242924
fn roundtrip_aggregate_over_subquery() -> Result<(), DataFusionError> {
2925-
let sql = r#"SELECT __agg_0 AS "min(j1_id)", __agg_1 AS "max(j1_id)" FROM (SELECT min(j1_rename) AS __agg_0, max(j1_rename) AS __agg_1 FROM (SELECT j1_id AS j1_rename FROM j1) AS bla LIMIT 20)"#;
2926-
2927-
let statement = Parser::new(&GenericDialect {})
2928-
.try_with_sql(sql)?
2929-
.parse_statement()?;
2930-
2931-
let state = MockSessionState::default()
2932-
.with_aggregate_function(max_udaf())
2933-
.with_aggregate_function(min_udaf())
2934-
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()))
2935-
.with_expr_planner(Arc::new(NestedFunctionPlanner))
2936-
.with_expr_planner(Arc::new(FieldAccessPlanner));
2937-
2938-
let context = MockContextProvider { state };
2939-
let sql_to_rel = SqlToRel::new(&context);
2940-
let plan = sql_to_rel
2941-
.sql_statement_to_plan(statement)
2942-
.unwrap_or_else(|e| panic!("Failed to parse sql: {sql}\n{e}"));
2943-
2944-
println!("Logical plan:\n{plan}");
2945-
println!(
2946-
"\nLogical plan (verbose):\n{}",
2947-
plan.display_indent_schema()
2925+
roundtrip_statement_with_dialect_helper!(
2926+
sql: r#"SELECT __agg_0 AS "min(j1_id)", __agg_1 AS "max(j1_id)" FROM (SELECT min(j1_rename) AS __agg_0, max(j1_rename) AS __agg_1 FROM (SELECT j1_id AS j1_rename FROM j1) AS bla LIMIT 20)"#,
2927+
parser_dialect: GenericDialect {},
2928+
unparser_dialect: UnparserDefaultDialect {},
2929+
expected: @r#"SELECT __agg_0 AS "min(j1_id)", __agg_1 AS "max(j1_id)" FROM (SELECT min(bla.j1_rename) AS __agg_0, max(bla.j1_rename) AS __agg_1 FROM (SELECT j1.j1_id AS j1_rename FROM j1) AS bla LIMIT 20)"#,
29482930
);
2949-
2950-
let unparser = Unparser::new(&UnparserDefaultDialect {});
2951-
let roundtrip_statement = unparser.plan_to_sql(&plan)?;
2952-
let actual = &roundtrip_statement.to_string();
2953-
2954-
insta::assert_snapshot!(actual, @r#"SELECT __agg_0 AS "min(j1_id)", __agg_1 AS "max(j1_id)" FROM (SELECT min(bla.j1_rename) AS __agg_0, max(bla.j1_rename) AS __agg_1 FROM (SELECT j1.j1_id AS j1_rename FROM j1) AS bla LIMIT 20)"#);
29552931
Ok(())
29562932
}
29572933

2958-
/// Same as roundtrip_aggregate_over_subquery but with the Projection between
2959-
/// Limit and Aggregate removed — the aliases are inlined into the Aggregate.
2960-
///
2961-
/// Plan shape:
2962-
/// Projection: __agg_0 AS "max1(j1_id)", __agg_1 AS "max2(j1_id)"
2963-
/// Limit: fetch=20
2964-
/// Aggregate: aggr=[[max(bla.j1_rename) AS __agg_0, max(bla.j1_rename) AS __agg_1]]
2965-
/// SubqueryAlias: bla
2966-
/// Projection: j1.j1_id AS j1_rename
2967-
/// TableScan: j1
2934+
/// Projection → Limit → Aggregate (aliases inlined into Aggregate, no
2935+
/// intermediate Projection). Verifies the Limit is folded into the outer
2936+
/// SELECT rather than creating a spurious derived subquery.
29682937
#[test]
29692938
fn test_unparse_aggregate_over_subquery_no_inner_proj() -> Result<()> {
29702939
let context = MockContextProvider {
@@ -2974,15 +2943,10 @@ fn test_unparse_aggregate_over_subquery_no_inner_proj() -> Result<()> {
29742943
.get_table_source(TableReference::bare("j1"))?
29752944
.schema();
29762945

2977-
// (SELECT j1_id AS j1_rename FROM j1) AS bla
29782946
let scan = table_scan(Some("j1"), &j1_schema, None)?.build()?;
2979-
let inner_subquery = LogicalPlanBuilder::from(scan)
2947+
let plan = LogicalPlanBuilder::from(scan)
29802948
.project(vec![col("j1.j1_id").alias("j1_rename")])?
29812949
.alias("bla")?
2982-
.build()?;
2983-
2984-
// Aggregate with aliases inlined (no separate Projection)
2985-
let plan = LogicalPlanBuilder::from(inner_subquery)
29862950
.aggregate(
29872951
vec![] as Vec<Expr>,
29882952
vec![
@@ -2997,29 +2961,13 @@ fn test_unparse_aggregate_over_subquery_no_inner_proj() -> Result<()> {
29972961
])?
29982962
.build()?;
29992963

3000-
println!("Logical plan:\n{plan}");
3001-
println!(
3002-
"\nLogical plan (verbose):\n{}",
3003-
plan.display_indent_schema()
3004-
);
3005-
3006-
let unparser = Unparser::default();
3007-
let sql = unparser.plan_to_sql(&plan)?.to_string();
3008-
println!("\nUnparsed SQL:\n{sql}");
3009-
2964+
let sql = Unparser::default().plan_to_sql(&plan)?.to_string();
2965+
insta::assert_snapshot!(sql, @r#"SELECT max(bla.j1_rename) AS "max1(j1_id)", max(bla.j1_rename) AS "max2(j1_id)" FROM (SELECT j1.j1_id AS j1_rename FROM j1) AS bla LIMIT 20"#);
30102966
Ok(())
30112967
}
30122968

3013-
/// Same as test_unparse_aggregate_over_subquery_no_inner_proj but the outer
3014-
/// Projection references the aggregate columns WITHOUT renaming them.
3015-
/// The output column names should still match the Aggregate's aliases.
3016-
///
3017-
/// Plan shape:
3018-
/// Projection: __agg_0, __agg_1
3019-
/// Aggregate: aggr=[[max(bla.j1_rename) AS __agg_0, max(bla.j1_rename) AS __agg_1]]
3020-
/// SubqueryAlias: bla
3021-
/// Projection: j1.j1_id AS j1_rename
3022-
/// TableScan: j1
2969+
/// Projection → Aggregate (aliases inlined, no rename in outer Projection).
2970+
/// Verifies the aggregate aliases are preserved as output column names.
30232971
#[test]
30242972
fn test_unparse_aggregate_no_outer_rename() -> Result<()> {
30252973
let context = MockContextProvider {
@@ -3030,12 +2978,9 @@ fn test_unparse_aggregate_no_outer_rename() -> Result<()> {
30302978
.schema();
30312979

30322980
let scan = table_scan(Some("j1"), &j1_schema, None)?.build()?;
3033-
let inner_subquery = LogicalPlanBuilder::from(scan)
2981+
let plan = LogicalPlanBuilder::from(scan)
30342982
.project(vec![col("j1.j1_id").alias("j1_rename")])?
30352983
.alias("bla")?
3036-
.build()?;
3037-
3038-
let plan = LogicalPlanBuilder::from(inner_subquery)
30392984
.aggregate(
30402985
vec![] as Vec<Expr>,
30412986
vec![
@@ -3046,16 +2991,37 @@ fn test_unparse_aggregate_no_outer_rename() -> Result<()> {
30462991
.project(vec![col("__agg_0"), col("__agg_1")])?
30472992
.build()?;
30482993

3049-
println!("Logical plan:\n{plan}");
3050-
println!(
3051-
"\nLogical plan (verbose):\n{}",
3052-
plan.display_indent_schema()
3053-
);
2994+
let sql = Unparser::default().plan_to_sql(&plan)?.to_string();
2995+
insta::assert_snapshot!(sql, @"SELECT max(bla.j1_rename) AS __agg_0, max(bla.j1_rename) AS __agg_1 FROM (SELECT j1.j1_id AS j1_rename FROM j1) AS bla");
2996+
Ok(())
2997+
}
30542998

3055-
let unparser = Unparser::default();
3056-
let sql = unparser.plan_to_sql(&plan)?.to_string();
3057-
println!("\nUnparsed SQL:\n{sql}");
2999+
/// Projection → Sort → Aggregate (aliases inlined into Aggregate).
3000+
/// Verifies the Sort is folded into the outer SELECT rather than creating
3001+
/// a spurious derived subquery.
3002+
#[test]
3003+
fn test_unparse_aggregate_with_sort_no_inner_proj() -> Result<()> {
3004+
let context = MockContextProvider {
3005+
state: MockSessionState::default(),
3006+
};
3007+
let j1_schema = context
3008+
.get_table_source(TableReference::bare("j1"))?
3009+
.schema();
3010+
3011+
let scan = table_scan(Some("j1"), &j1_schema, None)?.build()?;
3012+
let plan = LogicalPlanBuilder::from(scan)
3013+
.project(vec![col("j1.j1_id").alias("j1_rename")])?
3014+
.alias("bla")?
3015+
.aggregate(
3016+
vec![] as Vec<Expr>,
3017+
vec![max(col("bla.j1_rename")).alias("__agg_0")],
3018+
)?
3019+
.sort(vec![col("__agg_0").sort(true, true)])?
3020+
.project(vec![col("__agg_0").alias("max1(j1_id)")])?
3021+
.build()?;
30583022

3023+
let sql = Unparser::default().plan_to_sql(&plan)?.to_string();
3024+
insta::assert_snapshot!(sql, @r#"SELECT max(bla.j1_rename) AS "max1(j1_id)" FROM (SELECT j1.j1_id AS j1_rename FROM j1) AS bla ORDER BY max(bla.j1_rename) ASC NULLS FIRST"#);
30593025
Ok(())
30603026
}
30613027

0 commit comments

Comments
 (0)