Skip to content

Commit edab328

Browse files
yashydgandhi
authored and committed
fix: preserve aggregate order in multi-distinct COUNT rewrite and add tests
✅ Fix projection after join so output columns match the original aggregate list when COUNT(DISTINCT …) and non-distinct aggs are interleaved (schema-compatible with mixed BI-style queries). ✅ Add internal_err guard for inconsistent aggregate index mapping. ✅ Optimizer tests: three grouped COUNT(DISTINCT), non-distinct between distincts, CAST(distinct) args, no rewrite for GROUPING SETS. ✅ SQL integration: COUNT(*) + two COUNT(DISTINCT); two GROUP BY keys with expected results. ❌ Grouping-set / filtered-distinct cases remain explicitly out of scope for this rule (covered by unchanged-plan tests where applicable). Made-with: Cursor
1 parent 1d0733d commit edab328

2 files changed

Lines changed: 265 additions & 22 deletions

File tree

datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,78 @@ async fn multi_count_distinct_matches_expected_with_nulls() -> Result<()> {
5656
);
5757
Ok(())
5858
}
59+
60+
/// `COUNT(*)` + two `COUNT(DISTINCT …)` per group (BI-style); must match non-rewritten semantics.
61+
#[tokio::test]
62+
async fn multi_count_distinct_with_count_star_matches_expected() -> Result<()> {
63+
let ctx = SessionContext::new();
64+
let schema = Arc::new(Schema::new(vec![
65+
Field::new("g", DataType::Int32, false),
66+
Field::new("b", DataType::Int32, false),
67+
Field::new("c", DataType::Int32, false),
68+
]));
69+
let batch = RecordBatch::try_new(
70+
schema.clone(),
71+
vec![
72+
Arc::new(Int32Array::from(vec![1, 1, 1])),
73+
Arc::new(Int32Array::from(vec![1, 2, 1])),
74+
Arc::new(Int32Array::from(vec![10, 20, 30])),
75+
],
76+
)?;
77+
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
78+
ctx.register_table("t", Arc::new(provider))?;
79+
80+
let sql = "SELECT g, COUNT(*) AS n, COUNT(DISTINCT b) AS db, COUNT(DISTINCT c) AS dc \
81+
FROM t GROUP BY g";
82+
let batches = ctx.sql(sql).await?.collect().await?;
83+
let out = batches_to_sort_string(&batches);
84+
85+
assert_eq!(
86+
out,
87+
"+---+---+----+----+\n\
88+
| g | n | db | dc |\n\
89+
+---+---+----+----+\n\
90+
| 1 | 3 | 2 | 3 |\n\
91+
+---+---+----+----+"
92+
);
93+
Ok(())
94+
}
95+
96+
/// Multiple `GROUP BY` keys: join must align on all keys.
97+
#[tokio::test]
98+
async fn multi_count_distinct_two_group_keys_matches_expected() -> Result<()> {
99+
let ctx = SessionContext::new();
100+
let schema = Arc::new(Schema::new(vec![
101+
Field::new("g1", DataType::Int32, false),
102+
Field::new("g2", DataType::Int32, false),
103+
Field::new("b", DataType::Int32, false),
104+
Field::new("c", DataType::Int32, false),
105+
]));
106+
let batch = RecordBatch::try_new(
107+
schema.clone(),
108+
vec![
109+
Arc::new(Int32Array::from(vec![1, 1, 1])),
110+
Arc::new(Int32Array::from(vec![1, 1, 2])),
111+
Arc::new(Int32Array::from(vec![1, 1, 3])),
112+
Arc::new(Int32Array::from(vec![1, 2, 3])),
113+
],
114+
)?;
115+
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
116+
ctx.register_table("t", Arc::new(provider))?;
117+
118+
let sql = "SELECT g1, g2, COUNT(DISTINCT b) AS db, COUNT(DISTINCT c) AS dc \
119+
FROM t GROUP BY g1, g2";
120+
let batches = ctx.sql(sql).await?.collect().await?;
121+
let out = batches_to_sort_string(&batches);
122+
123+
assert_eq!(
124+
out,
125+
"+----+----+----+----+\n\
126+
| g1 | g2 | db | dc |\n\
127+
+----+----+----+----+\n\
128+
| 1 | 1 | 1 | 2 |\n\
129+
| 1 | 2 | 1 | 1 |\n\
130+
+----+----+----+----+"
131+
);
132+
Ok(())
133+
}

datafusion/optimizer/src/multi_distinct_count_rewrite.rs

Lines changed: 190 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::optimizer::ApplyOrder;
2525
use crate::{OptimizerConfig, OptimizerRule};
2626

2727
use datafusion_common::{
28-
Column, JoinConstraint, NullEquality, Result, tree_node::Transformed,
28+
Column, JoinConstraint, NullEquality, Result, internal_err, tree_node::Transformed,
2929
};
3030
use datafusion_expr::builder::project;
3131
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction};
@@ -205,10 +205,9 @@ impl OptimizerRule for MultiDistinctCountRewrite {
205205
base_aggr_exprs,
206206
)?);
207207
let base_alias = config.alias_generator().next("mdc_base");
208-
Some(Arc::new(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
209-
Arc::new(base_plan),
210-
&base_alias,
211-
)?)))
208+
Some(Arc::new(LogicalPlan::SubqueryAlias(
209+
SubqueryAlias::try_new(Arc::new(base_plan), &base_alias)?,
210+
)))
212211
};
213212

214213
let mut current = base_plan_opt;
@@ -285,9 +284,12 @@ impl OptimizerRule for MultiDistinctCountRewrite {
285284
};
286285
}
287286

288-
let current = current.expect("distinct_list non-empty implies at least one branch");
287+
let current =
288+
current.expect("distinct_list non-empty implies at least one branch");
289289
let join_schema = current.schema();
290290

291+
let base_field_count = group_size + other_list.len();
292+
291293
let mut proj_exprs: Vec<Expr> = vec![];
292294
for i in 0..group_size {
293295
let (q, f) = schema.qualified_field(i);
@@ -296,23 +298,36 @@ impl OptimizerRule for MultiDistinctCountRewrite {
296298
let c = Expr::Column(Column::new(join_q.cloned(), join_f.name()));
297299
proj_exprs.push(c.alias_qualified(q.cloned(), orig_name));
298300
}
299-
for (field_idx, (_, schema_aggr_idx)) in other_list.iter().enumerate() {
300-
let (q, f) = schema.qualified_field(*schema_aggr_idx);
301+
// Preserve original aggregate column order (distinct and non-distinct may be interleaved).
302+
for aggr_i in 0..aggr_expr.len() {
303+
let schema_idx = group_size + aggr_i;
304+
let (q, f) = schema.qualified_field(schema_idx);
301305
let orig_name = f.name();
302-
let join_idx = group_size + field_idx;
303-
let (join_q, join_f) = join_schema.qualified_field(join_idx);
304-
let c = Expr::Column(Column::new(join_q.cloned(), join_f.name()));
305-
proj_exprs.push(c.alias_qualified(q.cloned(), orig_name));
306-
}
307-
let base_field_count = group_size + other_list.len();
308-
for (idx, (_, schema_aggr_idx, _)) in distinct_list.iter().enumerate() {
309-
let (q, f) = schema.qualified_field(*schema_aggr_idx);
310-
let orig_name = f.name();
311-
let branch_start_idx = base_field_count + idx * (group_size + 1);
312-
let branch_aggr_idx = branch_start_idx + group_size;
313-
let (join_q, join_f) = join_schema.qualified_field(branch_aggr_idx);
314-
let c = Expr::Column(Column::new(join_q.cloned(), join_f.name()));
315-
proj_exprs.push(c.alias_qualified(q.cloned(), orig_name));
306+
307+
if let Some((dist_idx, (_, _, _))) = distinct_list
308+
.iter()
309+
.enumerate()
310+
.find(|(_, (_, idx, _))| *idx == schema_idx)
311+
{
312+
let branch_start_idx = base_field_count + dist_idx * (group_size + 1);
313+
let branch_aggr_idx = branch_start_idx + group_size;
314+
let (join_q, join_f) = join_schema.qualified_field(branch_aggr_idx);
315+
let c = Expr::Column(Column::new(join_q.cloned(), join_f.name()));
316+
proj_exprs.push(c.alias_qualified(q.cloned(), orig_name));
317+
} else if let Some((other_idx, _)) = other_list
318+
.iter()
319+
.enumerate()
320+
.find(|(_, (_, idx))| *idx == schema_idx)
321+
{
322+
let join_idx = group_size + other_idx;
323+
let (join_q, join_f) = join_schema.qualified_field(join_idx);
324+
let c = Expr::Column(Column::new(join_q.cloned(), join_f.name()));
325+
proj_exprs.push(c.alias_qualified(q.cloned(), orig_name));
326+
} else {
327+
return internal_err!(
328+
"aggregate index {aggr_i} (schema index {schema_idx}) is neither distinct nor other"
329+
);
330+
}
316331
}
317332

318333
let out = project((*current).clone(), proj_exprs)?;
@@ -327,8 +342,13 @@ mod tests {
327342
use crate::OptimizerContext;
328343
use crate::OptimizerRule;
329344
use crate::test::*;
345+
use arrow::datatypes::DataType;
346+
use datafusion_expr::GroupingSet;
330347
use datafusion_expr::LogicalPlan;
348+
use datafusion_expr::expr_fn::cast;
349+
use datafusion_expr::logical_plan::Aggregate;
331350
use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
351+
use datafusion_expr::{Expr, col};
332352
use datafusion_functions_aggregate::expr_fn::{count, count_distinct};
333353

334354
fn optimize_with_rule(
@@ -403,6 +423,53 @@ mod tests {
403423
Ok(())
404424
}
405425

426+
/// Two `COUNT(DISTINCT …)` aggregates followed by a plain `COUNT` under one `GROUP BY`
/// (the typical BI shape). The rewritten plan is expected to contain an inner join and
/// an `mdc_base` subquery alias carrying the non-distinct aggregate.
#[test]
fn rewrites_two_count_distinct_with_non_distinct_count() -> Result<()> {
    let group_by = vec![col("a")];
    let aggregates = vec![
        count_distinct(col("b")),
        count_distinct(col("c")),
        count(col("a")),
    ];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_by, aggregates)?
        .build()?;

    let rule = Arc::new(MultiDistinctCountRewrite::new());
    let s = optimize_with_rule(plan, rule)?
        .display_indent_schema()
        .to_string();

    assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}");
    assert!(
        s.contains("SubqueryAlias: mdc_base"),
        "expected base aggregate for non-distinct aggs, got:\n{s}"
    );
    Ok(())
}
452+
453+
/// Two `COUNT(DISTINCT b)` aggregates over the same column (differing only in alias)
/// must leave the plan untouched: its rendering before and after the rule is identical.
#[test]
fn does_not_rewrite_two_count_distinct_same_column() -> Result<()> {
    let aggregates = vec![
        count_distinct(col("b")).alias("cd1"),
        count_distinct(col("b")).alias("cd2"),
    ];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(vec![col("a")], aggregates)?
        .build()?;

    // Render the input first; `plan` is moved into the optimizer below.
    let rendered_input = plan.display_indent_schema().to_string();
    let rewritten =
        optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?;
    let rendered_output = rewritten.display_indent_schema().to_string();

    assert_eq!(rendered_input, rendered_output);
    Ok(())
}
472+
406473
#[test]
407474
fn does_not_rewrite_single_count_distinct() -> Result<()> {
408475
let table_scan = test_table_scan()?;
@@ -417,6 +484,107 @@ mod tests {
417484
Ok(())
418485
}
419486

487+
/// Three `COUNT(DISTINCT …)` aggregates under one `GROUP BY`: the rewritten plan is
/// expected to show at least two inner joins and the `mdc_base` subquery alias.
#[test]
fn rewrites_three_count_distinct_grouped() -> Result<()> {
    let aggregates = vec![
        count_distinct(col("b")),
        count_distinct(col("c")),
        count_distinct(col("a")),
    ];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(vec![col("a")], aggregates)?
        .build()?;

    let rule = Arc::new(MultiDistinctCountRewrite::new());
    let s = optimize_with_rule(plan, rule)?
        .display_indent_schema()
        .to_string();

    let join_count = s.matches("Inner Join").count();
    assert!(
        join_count >= 2,
        "expected two joins for three branches, got:\n{s}"
    );
    assert!(
        s.contains("SubqueryAlias: mdc_base"),
        "expected base aggregate, got:\n{s}"
    );
    Ok(())
}
514+
515+
/// A non-distinct `COUNT` sitting between two `COUNT(DISTINCT …)` aggregates must still
/// be rewritten into the join plan, and the rewritten plan must expose its columns in
/// the same order as the original aggregate list.
#[test]
fn rewrites_interleaved_non_distinct_between_distincts() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            vec![col("a")],
            vec![
                count_distinct(col("b")),
                count(col("a")),
                count_distinct(col("c")),
            ],
        )?
        .build()?;
    // Capture the original output columns before `plan` is moved into the optimizer.
    let expected_columns = plan.schema().field_names();

    let optimized =
        optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?;
    let s = optimized.display_indent_schema().to_string();
    assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}");
    assert!(
        s.contains("SubqueryAlias: mdc_base"),
        "expected base for middle count(a), got:\n{s}"
    );
    // The point of the interleaved case: the final projection must put each aggregate
    // back at its original position, not group distinct and non-distinct separately.
    assert_eq!(
        optimized.schema().field_names(),
        expected_columns,
        "rewrite must preserve original output column order, got:\n{s}"
    );
    Ok(())
}
539+
540+
/// `COUNT(DISTINCT CAST(… AS Int64))` arguments are rewritten as well: the optimized
/// plan is expected to contain an inner join plus an `IS NOT NULL` filter on each cast
/// expression.
#[test]
fn rewrites_count_distinct_on_cast_exprs() -> Result<()> {
    let distinct_b = count_distinct(cast(col("b"), DataType::Int64));
    let distinct_c = count_distinct(cast(col("c"), DataType::Int64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(vec![col("a")], vec![distinct_b, distinct_c])?
        .build()?;

    let rule = Arc::new(MultiDistinctCountRewrite::new());
    let s = optimize_with_rule(plan, rule)?
        .display_indent_schema()
        .to_string();

    assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}");
    assert!(
        s.contains("Filter: CAST(test.b AS Int64) IS NOT NULL"),
        "expected null filter on cast(b), got:\n{s}"
    );
    assert!(
        s.contains("Filter: CAST(test.c AS Int64) IS NOT NULL"),
        "expected null filter on cast(c), got:\n{s}"
    );
    Ok(())
}
567+
568+
/// An aggregate grouped by `GROUPING SETS` is out of scope for the rule: the plan's
/// rendering must be identical before and after optimization.
#[test]
fn does_not_rewrite_grouping_sets_multi_distinct() -> Result<()> {
    let grouping_sets = GroupingSet::GroupingSets(vec![vec![col("a")]]);
    // Built through `Aggregate::try_new` directly, with the grouping set as the only
    // group expression.
    let aggregate = Aggregate::try_new(
        Arc::new(test_table_scan()?),
        vec![Expr::GroupingSet(grouping_sets)],
        vec![count_distinct(col("b")), count_distinct(col("c"))],
    )?;
    let plan = LogicalPlan::Aggregate(aggregate);

    let rendered_input = plan.display_indent_schema().to_string();
    let rewritten =
        optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?;
    let rendered_output = rewritten.display_indent_schema().to_string();

    assert_eq!(rendered_input, rendered_output);
    Ok(())
}
587+
420588
#[test]
421589
fn does_not_rewrite_mixed_agg() -> Result<()> {
422590
let table_scan = test_table_scan()?;

0 commit comments

Comments
 (0)