Skip to content

Commit d68373e

Browse files
timsaucerclaude
andauthored
fix: grouping with alias (#21438)
## Which issue does this PR close? - Closes #21411 ## Rationale for this change When you have an alias on `grouping` function via dataframe API, you get an error. This resolves that error. ## What changes are included in this PR? Check for alias expressions in optimizer. Add unit tests. ## Are these changes tested? Unit test added, including a note on why a SQL logic test will not cover this case. ## Are there any user-facing changes? None --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bb1c8e6 commit d68373e

2 files changed

Lines changed: 90 additions & 8 deletions

File tree

datafusion/core/tests/dataframe/mod.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6851,3 +6851,50 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {
68516851

68526852
Ok(())
68536853
}
6854+
6855+
/// Regression test for https://github.com/apache/datafusion/issues/21411
6856+
/// grouping() should work when wrapped in an alias via the DataFrame API.
6857+
///
6858+
/// This bug only manifests through the DataFrame API because `.alias()` wraps
6859+
/// the `grouping()` call in an `Expr::Alias` node at the aggregate expression
6860+
/// level. The SQL planner handles aliasing separately (via projection), so the
6861+
/// `ResolveGroupingFunction` analyzer rule never sees an `Expr::Alias` wrapper
6862+
/// around the aggregate function in SQL queries — making SQL-based tests
6863+
/// insufficient to cover this case.
6864+
#[tokio::test]
6865+
async fn test_grouping_with_alias() -> Result<()> {
6866+
use datafusion_functions_aggregate::expr_fn::grouping;
6867+
6868+
let df = create_test_table("test")
6869+
.await?
6870+
.aggregate(vec![col("a")], vec![grouping(col("a")).alias("g")])?
6871+
.sort(vec![Sort::new(col("a"), true, false)])?;
6872+
6873+
let results = df.collect().await?;
6874+
6875+
let expected = [
6876+
"+-----------+---+",
6877+
"| a | g |",
6878+
"+-----------+---+",
6879+
"| 123AbcDef | 0 |",
6880+
"| CBAdef | 0 |",
6881+
"| abc123 | 0 |",
6882+
"| abcDEF | 0 |",
6883+
"+-----------+---+",
6884+
];
6885+
assert_batches_eq!(expected, &results);
6886+
6887+
// Also verify that nested aliases (e.g. .alias("x").alias("g")) work correctly
6888+
let df = create_test_table("test")
6889+
.await?
6890+
.aggregate(
6891+
vec![col("a")],
6892+
vec![grouping(col("a")).alias("x").alias("g")],
6893+
)?
6894+
.sort(vec![Sort::new(col("a"), true, false)])?;
6895+
6896+
let results = df.collect().await?;
6897+
assert_batches_eq!(expected, &results);
6898+
6899+
Ok(())
6900+
}

datafusion/optimizer/src/analyzer/resolve_grouping_function.rs

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,15 @@ fn replace_grouping_exprs(
9797
.into_iter()
9898
.zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
9999
{
100+
let grouping_id_type = is_grouping_set
101+
.then(|| {
102+
schema
103+
.field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
104+
.map(|f| f.data_type().clone())
105+
})
106+
.transpose()?;
100107
match expr {
101108
Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
102-
let grouping_id_type = is_grouping_set
103-
.then(|| {
104-
schema
105-
.field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
106-
.map(|f| f.data_type().clone())
107-
})
108-
.transpose()?;
109109
let grouping_expr = grouping_function_on_id(
110110
function,
111111
&group_expr_to_bitmap_index,
@@ -117,6 +117,24 @@ fn replace_grouping_exprs(
117117
column.name,
118118
)));
119119
}
120+
Expr::Alias(Alias {
121+
ref relation,
122+
ref name,
123+
..
124+
}) if is_grouping_function(&expr) => {
125+
let function = unwrap_alias_to_grouping_function(&expr)?;
126+
let grouping_expr = grouping_function_on_id(
127+
function,
128+
&group_expr_to_bitmap_index,
129+
grouping_id_type,
130+
)?;
131+
// Preserve the outermost user-provided alias
132+
projection_exprs.push(Expr::Alias(Alias::new(
133+
grouping_expr,
134+
relation.clone(),
135+
name.clone(),
136+
)));
137+
}
120138
_ => {
121139
projection_exprs.push(Expr::Column(column));
122140
new_agg_expr.push(expr);
@@ -155,10 +173,27 @@ fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
155173
Ok(transformed_plan)
156174
}
157175

176+
/// Recursively unwrap `Expr::Alias` nodes to reach the inner `AggregateFunction`.
177+
/// Returns an error if the innermost expression is not an `AggregateFunction`,
178+
/// which should not happen if `is_grouping_function` returned true.
179+
fn unwrap_alias_to_grouping_function(expr: &Expr) -> Result<&AggregateFunction> {
180+
match expr {
181+
Expr::AggregateFunction(function) => Ok(function),
182+
Expr::Alias(Alias { expr, .. }) => unwrap_alias_to_grouping_function(expr),
183+
_ => plan_err!("Expected grouping aggregate function inside alias, got {expr}"),
184+
}
185+
}
186+
158187
fn is_grouping_function(expr: &Expr) -> bool {
159188
// TODO: Do something better than name here should grouping be a built
160189
// in expression?
161-
matches!(expr, Expr::AggregateFunction(AggregateFunction { func, .. }) if func.name() == "grouping")
190+
match expr {
191+
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
192+
func.name() == "grouping"
193+
}
194+
Expr::Alias(Alias { expr, .. }) => is_grouping_function(expr),
195+
_ => false,
196+
}
162197
}
163198

164199
fn contains_grouping_function(exprs: &[Expr]) -> bool {

0 commit comments

Comments
 (0)