Skip to content

Commit 9a5c61d

Browse files
authored
Substrait: handle identical grouping expressions (#16189)
* Substrait: handle identical grouping expressions * fix
1 parent 795988d commit 9a5c61d

3 files changed

Lines changed: 78 additions & 1 deletion

File tree

datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::logical_plan::consumer::SubstraitConsumer;
1918
use crate::logical_plan::consumer::{from_substrait_agg_func, from_substrait_sorts};
19+
use crate::logical_plan::consumer::{NameTracker, SubstraitConsumer};
2020
use datafusion::common::{not_impl_err, DFSchemaRef};
2121
use datafusion::logical_expr::{Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder};
2222
use substrait::proto::aggregate_function::AggregationInvocation;
@@ -113,6 +113,14 @@ pub async fn from_aggregate_rel(
113113
};
114114
aggr_exprs.push(agg_func?.as_ref().clone());
115115
}
116+
117+
// Ensure that all expressions have a unique name
118+
let mut name_tracker = NameTracker::new();
119+
let group_exprs = group_exprs
120+
.iter()
121+
.map(|e| name_tracker.get_uniquely_named_expr(e.clone()))
122+
.collect::<Result<Vec<Expr>, _>>()?;
123+
116124
input.aggregate(group_exprs, aggr_exprs)?.build()
117125
} else {
118126
not_impl_err!("Aggregate without an input is not valid")

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,22 @@ async fn aggregate_wo_projection_sorted_consume() -> Result<()> {
854854
Ok(())
855855
}
856856

857+
#[tokio::test]
858+
async fn aggregate_identical_grouping_expressions() -> Result<()> {
859+
let proto_plan =
860+
read_json("tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json");
861+
862+
let plan = generate_plan_from_substrait(proto_plan).await?;
863+
assert_snapshot!(
864+
plan,
865+
@r#"
866+
Aggregate: groupBy=[[Int32(1) AS grouping_col_1, Int32(1) AS grouping_col_2]], aggr=[[]]
867+
TableScan: data projection=[]
868+
"#
869+
);
870+
Ok(())
871+
}
872+
857873
#[tokio::test]
858874
async fn simple_intersect_consume() -> Result<()> {
859875
let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json");
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
{
2+
"extensionUris": [],
3+
"extensions": [],
4+
"relations": [
5+
{
6+
"root": {
7+
"input": {
8+
"aggregate": {
9+
"input": {
10+
"read": {
11+
"common": {
12+
"direct": {}
13+
},
14+
"baseSchema": {
15+
"names": [],
16+
"struct": {
17+
"types": [],
18+
"nullability": "NULLABILITY_NULLABLE"
19+
}
20+
},
21+
"namedTable": {
22+
"names": ["data"]
23+
}
24+
}
25+
},
26+
"groupings": [
27+
{
28+
"groupingExpressions": [
29+
{
30+
"literal": {
31+
"i32": 1
32+
}
33+
},
34+
{
35+
"literal": {
36+
"i32": 1
37+
}
38+
}
39+
]
40+
}
41+
],
42+
"measures": []
43+
}
44+
},
45+
"names": ["grouping_col_1", "grouping_col_2"]
46+
}
47+
}
48+
],
49+
"version": {
50+
"minorNumber": 54,
51+
"producer": "manual"
52+
}
53+
}

0 commit comments

Comments
 (0)