diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 6e22aeab61089..26de1e23b38e1 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -6851,3 +6851,50 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {
 
     Ok(())
 }
+
+/// Regression test for https://github.com/apache/datafusion/issues/21411
+/// grouping() should work when wrapped in an alias via the DataFrame API.
+///
+/// This bug only manifests through the DataFrame API because `.alias()` wraps
+/// the `grouping()` call in an `Expr::Alias` node at the aggregate expression
+/// level. The SQL planner handles aliasing separately (via projection), so the
+/// `ResolveGroupingFunction` analyzer rule never sees an `Expr::Alias` wrapper
+/// around the aggregate function in SQL queries — making SQL-based tests
+/// insufficient to cover this case.
+#[tokio::test]
+async fn test_grouping_with_alias() -> Result<()> {
+    use datafusion_functions_aggregate::expr_fn::grouping;
+
+    let df = create_test_table("test")
+        .await?
+        .aggregate(vec![col("a")], vec![grouping(col("a")).alias("g")])?
+        .sort(vec![Sort::new(col("a"), true, false)])?;
+
+    let results = df.collect().await?;
+
+    let expected = [
+        "+-----------+---+",
+        "| a         | g |",
+        "+-----------+---+",
+        "| 123AbcDef | 0 |",
+        "| CBAdef    | 0 |",
+        "| abc123    | 0 |",
+        "| abcDEF    | 0 |",
+        "+-----------+---+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    // Also verify that nested aliases (e.g. .alias("x").alias("g")) work correctly
+    let df = create_test_table("test")
+        .await?
+        .aggregate(
+            vec![col("a")],
+            vec![grouping(col("a")).alias("x").alias("g")],
+        )?
+        .sort(vec![Sort::new(col("a"), true, false)])?;
+
+    let results = df.collect().await?;
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
index c12d7fd2ec2f6..95649ab8286b7 100644
--- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
+++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
@@ -97,15 +97,15 @@ fn replace_grouping_exprs(
         .into_iter()
         .zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
     {
+        let grouping_id_type = is_grouping_set
+            .then(|| {
+                schema
+                    .field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
+                    .map(|f| f.data_type().clone())
+            })
+            .transpose()?;
         match expr {
             Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
-                let grouping_id_type = is_grouping_set
-                    .then(|| {
-                        schema
-                            .field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
-                            .map(|f| f.data_type().clone())
-                    })
-                    .transpose()?;
                 let grouping_expr = grouping_function_on_id(
                     function,
                     &group_expr_to_bitmap_index,
@@ -117,6 +117,24 @@ fn replace_grouping_exprs(
                     column.name,
                 )));
             }
+            Expr::Alias(Alias {
+                ref relation,
+                ref name,
+                ..
+            }) if is_grouping_function(&expr) => {
+                let function = unwrap_alias_to_grouping_function(&expr)?;
+                let grouping_expr = grouping_function_on_id(
+                    function,
+                    &group_expr_to_bitmap_index,
+                    grouping_id_type,
+                )?;
+                // Preserve the outermost user-provided alias
+                projection_exprs.push(Expr::Alias(Alias::new(
+                    grouping_expr,
+                    relation.clone(),
+                    name.clone(),
+                )));
+            }
             _ => {
                 projection_exprs.push(Expr::Column(column));
                 new_agg_expr.push(expr);
@@ -155,10 +173,27 @@ fn analyze_internal(plan: LogicalPlan) -> Result> {
     Ok(transformed_plan)
 }
 
+/// Recursively unwrap `Expr::Alias` nodes to reach the inner `AggregateFunction`.
+/// Returns an error if the innermost expression is not an `AggregateFunction`,
+/// which should not happen if `is_grouping_function` returned true.
+fn unwrap_alias_to_grouping_function(expr: &Expr) -> Result<&AggregateFunction> {
+    match expr {
+        Expr::AggregateFunction(function) => Ok(function),
+        Expr::Alias(Alias { expr, .. }) => unwrap_alias_to_grouping_function(expr),
+        _ => plan_err!("Expected grouping aggregate function inside alias, got {expr}"),
+    }
+}
+
 fn is_grouping_function(expr: &Expr) -> bool {
     // TODO: Do something better than name here should grouping be a built
     // in expression?
-    matches!(expr, Expr::AggregateFunction(AggregateFunction { func, .. }) if func.name() == "grouping")
+    match expr {
+        Expr::AggregateFunction(AggregateFunction { func, .. }) => {
+            func.name() == "grouping"
+        }
+        Expr::Alias(Alias { expr, .. }) => is_grouping_function(expr),
+        _ => false,
+    }
 }
 
 fn contains_grouping_function(exprs: &[Expr]) -> bool {