Skip to content

Commit cc8e1e6

Browse files
author
yash
committed
fix: global multi-COUNT(DISTINCT) rewrite without invalid empty Aggregate
✅ Omit the base Aggregate when GROUP BY is empty and only COUNT(DISTINCT) branches exist (this matches the ClickBench extended global queries). ✅ The first distinct branch seeds the plan; each subsequent branch is joined to it (with empty join keys, which appears as a Cross Join in the plan). ✅ Add the rewrites_global_three_count_distinct unit test. ❌ The previous plan shape could fail with an error, because it built an Aggregate with no grouping expressions and no aggregate expressions.
1 parent 59fa290 commit cc8e1e6

1 file changed

Lines changed: 77 additions & 36 deletions

File tree

datafusion/optimizer/src/multi_distinct_count_rewrite.rs

Lines changed: 77 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -192,19 +192,26 @@ impl OptimizerRule for MultiDistinctCountRewrite {
192192
e.clone().alias_qualified(q.cloned(), f.name())
193193
})
194194
.collect();
195-
let base_plan = LogicalPlan::Aggregate(Aggregate::try_new(
196-
Arc::clone(&input),
197-
group_expr.clone(),
198-
base_aggr_exprs,
199-
)?);
200195

201-
let base_alias = config.alias_generator().next("mdc_base");
202-
let base_aliased = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
203-
Arc::new(base_plan),
204-
&base_alias,
205-
)?);
206-
207-
let mut current = Arc::new(base_aliased);
196+
// `Aggregate` must have at least one of grouping exprs or aggregate exprs.
197+
// Global multi-`COUNT(DISTINCT)` (no GROUP BY, no other aggs) has neither — skip a base node.
198+
let base_plan_opt: Option<Arc<LogicalPlan>> =
199+
if group_expr.is_empty() && other_list.is_empty() {
200+
None
201+
} else {
202+
let base_plan = LogicalPlan::Aggregate(Aggregate::try_new(
203+
Arc::clone(&input),
204+
group_expr.clone(),
205+
base_aggr_exprs,
206+
)?);
207+
let base_alias = config.alias_generator().next("mdc_base");
208+
Some(Arc::new(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
209+
Arc::new(base_plan),
210+
&base_alias,
211+
)?)))
212+
};
213+
214+
let mut current = base_plan_opt;
208215

209216
for (distinct_arg, schema_aggr_idx, _) in distinct_list.iter() {
210217
// COUNT(DISTINCT x) ignores NULLs; filter before grouping by x.
@@ -247,32 +254,38 @@ impl OptimizerRule for MultiDistinctCountRewrite {
247254
&alias_name,
248255
)?);
249256

250-
let left_schema = current.schema();
251-
let right_schema = branch_aliased.schema();
252-
let join_keys: Vec<(Expr, Expr)> = (0..group_size)
253-
.map(|i| {
254-
let (lq, lf) = left_schema.qualified_field(i);
255-
let (rq, rf) = right_schema.qualified_field(i);
256-
(
257-
Expr::Column(Column::new(lq.cloned(), lf.name())),
258-
Expr::Column(Column::new(rq.cloned(), rf.name())),
259-
)
260-
})
261-
.collect();
262-
263-
let join = Join::try_new(
264-
current,
265-
Arc::new(branch_aliased),
266-
join_keys,
267-
None,
268-
JoinType::Inner,
269-
JoinConstraint::On,
270-
NullEquality::NullEqualsNothing,
271-
false,
272-
)?;
273-
current = Arc::new(LogicalPlan::Join(join));
257+
current = match current {
258+
None => Some(Arc::new(branch_aliased)),
259+
Some(prev) => {
260+
let left_schema = prev.schema();
261+
let right_schema = branch_aliased.schema();
262+
let join_keys: Vec<(Expr, Expr)> = (0..group_size)
263+
.map(|i| {
264+
let (lq, lf) = left_schema.qualified_field(i);
265+
let (rq, rf) = right_schema.qualified_field(i);
266+
(
267+
Expr::Column(Column::new(lq.cloned(), lf.name())),
268+
Expr::Column(Column::new(rq.cloned(), rf.name())),
269+
)
270+
})
271+
.collect();
272+
273+
let join = Join::try_new(
274+
prev,
275+
Arc::new(branch_aliased),
276+
join_keys,
277+
None,
278+
JoinType::Inner,
279+
JoinConstraint::On,
280+
NullEquality::NullEqualsNothing,
281+
false,
282+
)?;
283+
Some(Arc::new(LogicalPlan::Join(join)))
284+
}
285+
};
274286
}
275287

288+
let current = current.expect("distinct_list non-empty implies at least one branch");
276289
let join_schema = current.schema();
277290

278291
let mut proj_exprs: Vec<Expr> = vec![];
@@ -362,6 +375,34 @@ mod tests {
362375
Ok(())
363376
}
364377

378+
#[test]
379+
fn rewrites_global_three_count_distinct() -> Result<()> {
380+
let table_scan = test_table_scan()?;
381+
let plan = LogicalPlanBuilder::from(table_scan)
382+
.aggregate(
383+
Vec::<Expr>::new(),
384+
vec![
385+
count_distinct(col("a")),
386+
count_distinct(col("b")),
387+
count_distinct(col("c")),
388+
],
389+
)?
390+
.build()?;
391+
392+
let optimized =
393+
optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?;
394+
let s = optimized.display_indent_schema().to_string();
395+
assert!(
396+
s.contains("Cross Join") || s.contains("Inner Join"),
397+
"expected join rewrite for global multi-distinct, got:\n{s}"
398+
);
399+
assert!(
400+
!s.contains("mdc_base"),
401+
"global-only rewrite should not use mdc_base, got:\n{s}"
402+
);
403+
Ok(())
404+
}
405+
365406
#[test]
366407
fn does_not_rewrite_single_count_distinct() -> Result<()> {
367408
let table_scan = test_table_scan()?;

0 commit comments

Comments
 (0)