@@ -25,7 +25,7 @@ use crate::optimizer::ApplyOrder;
2525use crate :: { OptimizerConfig , OptimizerRule } ;
2626
2727use datafusion_common:: {
28- Column , JoinConstraint , NullEquality , Result , tree_node:: Transformed ,
28+ Column , JoinConstraint , NullEquality , Result , internal_err , tree_node:: Transformed ,
2929} ;
3030use datafusion_expr:: builder:: project;
3131use datafusion_expr:: expr:: { AggregateFunction , AggregateFunctionParams , ScalarFunction } ;
@@ -205,10 +205,9 @@ impl OptimizerRule for MultiDistinctCountRewrite {
205205 base_aggr_exprs,
206206 ) ?) ;
207207 let base_alias = config. alias_generator ( ) . next ( "mdc_base" ) ;
208- Some ( Arc :: new ( LogicalPlan :: SubqueryAlias ( SubqueryAlias :: try_new (
209- Arc :: new ( base_plan) ,
210- & base_alias,
211- ) ?) ) )
208+ Some ( Arc :: new ( LogicalPlan :: SubqueryAlias (
209+ SubqueryAlias :: try_new ( Arc :: new ( base_plan) , & base_alias) ?,
210+ ) ) )
212211 } ;
213212
214213 let mut current = base_plan_opt;
@@ -285,9 +284,12 @@ impl OptimizerRule for MultiDistinctCountRewrite {
285284 } ;
286285 }
287286
288- let current = current. expect ( "distinct_list non-empty implies at least one branch" ) ;
287+ let current =
288+ current. expect ( "distinct_list non-empty implies at least one branch" ) ;
289289 let join_schema = current. schema ( ) ;
290290
291+ let base_field_count = group_size + other_list. len ( ) ;
292+
291293 let mut proj_exprs: Vec < Expr > = vec ! [ ] ;
292294 for i in 0 ..group_size {
293295 let ( q, f) = schema. qualified_field ( i) ;
@@ -296,23 +298,36 @@ impl OptimizerRule for MultiDistinctCountRewrite {
296298 let c = Expr :: Column ( Column :: new ( join_q. cloned ( ) , join_f. name ( ) ) ) ;
297299 proj_exprs. push ( c. alias_qualified ( q. cloned ( ) , orig_name) ) ;
298300 }
299- for ( field_idx, ( _, schema_aggr_idx) ) in other_list. iter ( ) . enumerate ( ) {
300- let ( q, f) = schema. qualified_field ( * schema_aggr_idx) ;
301+ // Preserve original aggregate column order (distinct and non-distinct may be interleaved).
302+ for aggr_i in 0 ..aggr_expr. len ( ) {
303+ let schema_idx = group_size + aggr_i;
304+ let ( q, f) = schema. qualified_field ( schema_idx) ;
301305 let orig_name = f. name ( ) ;
302- let join_idx = group_size + field_idx;
303- let ( join_q, join_f) = join_schema. qualified_field ( join_idx) ;
304- let c = Expr :: Column ( Column :: new ( join_q. cloned ( ) , join_f. name ( ) ) ) ;
305- proj_exprs. push ( c. alias_qualified ( q. cloned ( ) , orig_name) ) ;
306- }
307- let base_field_count = group_size + other_list. len ( ) ;
308- for ( idx, ( _, schema_aggr_idx, _) ) in distinct_list. iter ( ) . enumerate ( ) {
309- let ( q, f) = schema. qualified_field ( * schema_aggr_idx) ;
310- let orig_name = f. name ( ) ;
311- let branch_start_idx = base_field_count + idx * ( group_size + 1 ) ;
312- let branch_aggr_idx = branch_start_idx + group_size;
313- let ( join_q, join_f) = join_schema. qualified_field ( branch_aggr_idx) ;
314- let c = Expr :: Column ( Column :: new ( join_q. cloned ( ) , join_f. name ( ) ) ) ;
315- proj_exprs. push ( c. alias_qualified ( q. cloned ( ) , orig_name) ) ;
306+
307+ if let Some ( ( dist_idx, ( _, _, _) ) ) = distinct_list
308+ . iter ( )
309+ . enumerate ( )
310+ . find ( |( _, ( _, idx, _) ) | * idx == schema_idx)
311+ {
312+ let branch_start_idx = base_field_count + dist_idx * ( group_size + 1 ) ;
313+ let branch_aggr_idx = branch_start_idx + group_size;
314+ let ( join_q, join_f) = join_schema. qualified_field ( branch_aggr_idx) ;
315+ let c = Expr :: Column ( Column :: new ( join_q. cloned ( ) , join_f. name ( ) ) ) ;
316+ proj_exprs. push ( c. alias_qualified ( q. cloned ( ) , orig_name) ) ;
317+ } else if let Some ( ( other_idx, _) ) = other_list
318+ . iter ( )
319+ . enumerate ( )
320+ . find ( |( _, ( _, idx) ) | * idx == schema_idx)
321+ {
322+ let join_idx = group_size + other_idx;
323+ let ( join_q, join_f) = join_schema. qualified_field ( join_idx) ;
324+ let c = Expr :: Column ( Column :: new ( join_q. cloned ( ) , join_f. name ( ) ) ) ;
325+ proj_exprs. push ( c. alias_qualified ( q. cloned ( ) , orig_name) ) ;
326+ } else {
327+ return internal_err ! (
328+ "aggregate index {aggr_i} (schema index {schema_idx}) is neither distinct nor other"
329+ ) ;
330+ }
316331 }
317332
318333 let out = project ( ( * current) . clone ( ) , proj_exprs) ?;
@@ -327,8 +342,13 @@ mod tests {
327342 use crate :: OptimizerContext ;
328343 use crate :: OptimizerRule ;
329344 use crate :: test:: * ;
345+ use arrow:: datatypes:: DataType ;
346+ use datafusion_expr:: GroupingSet ;
330347 use datafusion_expr:: LogicalPlan ;
348+ use datafusion_expr:: expr_fn:: cast;
349+ use datafusion_expr:: logical_plan:: Aggregate ;
331350 use datafusion_expr:: logical_plan:: builder:: LogicalPlanBuilder ;
351+ use datafusion_expr:: { Expr , col} ;
332352 use datafusion_functions_aggregate:: expr_fn:: { count, count_distinct} ;
333353
334354 fn optimize_with_rule (
@@ -403,6 +423,53 @@ mod tests {
403423 Ok ( ( ) )
404424 }
405425
426+ /// Grouped query with multiple `COUNT(DISTINCT …)` **and** non-distinct aggregates (typical BI).
427+ /// Non-distinct aggs live in `mdc_base`; each distinct column gets a branch + join on keys.
428+ #[ test]
429+ fn rewrites_two_count_distinct_with_non_distinct_count ( ) -> Result < ( ) > {
430+ let table_scan = test_table_scan ( ) ?;
431+ let plan = LogicalPlanBuilder :: from ( table_scan)
432+ . aggregate (
433+ vec ! [ col( "a" ) ] ,
434+ vec ! [
435+ count_distinct( col( "b" ) ) ,
436+ count_distinct( col( "c" ) ) ,
437+ count( col( "a" ) ) ,
438+ ] ,
439+ ) ?
440+ . build ( ) ?;
441+
442+ let optimized =
443+ optimize_with_rule ( plan, Arc :: new ( MultiDistinctCountRewrite :: new ( ) ) ) ?;
444+ let s = optimized. display_indent_schema ( ) . to_string ( ) ;
445+ assert ! ( s. contains( "Inner Join" ) , "expected join rewrite, got:\n {s}" ) ;
446+ assert ! (
447+ s. contains( "SubqueryAlias: mdc_base" ) ,
448+ "expected base aggregate for non-distinct aggs, got:\n {s}"
449+ ) ;
450+ Ok ( ( ) )
451+ }
452+
453+ #[ test]
454+ fn does_not_rewrite_two_count_distinct_same_column ( ) -> Result < ( ) > {
455+ let table_scan = test_table_scan ( ) ?;
456+ let plan = LogicalPlanBuilder :: from ( table_scan)
457+ . aggregate (
458+ vec ! [ col( "a" ) ] ,
459+ vec ! [
460+ count_distinct( col( "b" ) ) . alias( "cd1" ) ,
461+ count_distinct( col( "b" ) ) . alias( "cd2" ) ,
462+ ] ,
463+ ) ?
464+ . build ( ) ?;
465+ let before = plan. display_indent_schema ( ) . to_string ( ) ;
466+ let optimized =
467+ optimize_with_rule ( plan, Arc :: new ( MultiDistinctCountRewrite :: new ( ) ) ) ?;
468+ let after = optimized. display_indent_schema ( ) . to_string ( ) ;
469+ assert_eq ! ( before, after) ;
470+ Ok ( ( ) )
471+ }
472+
406473 #[ test]
407474 fn does_not_rewrite_single_count_distinct ( ) -> Result < ( ) > {
408475 let table_scan = test_table_scan ( ) ?;
@@ -417,6 +484,107 @@ mod tests {
417484 Ok ( ( ) )
418485 }
419486
487+ #[ test]
488+ fn rewrites_three_count_distinct_grouped ( ) -> Result < ( ) > {
489+ let table_scan = test_table_scan ( ) ?;
490+ let plan = LogicalPlanBuilder :: from ( table_scan)
491+ . aggregate (
492+ vec ! [ col( "a" ) ] ,
493+ vec ! [
494+ count_distinct( col( "b" ) ) ,
495+ count_distinct( col( "c" ) ) ,
496+ count_distinct( col( "a" ) ) ,
497+ ] ,
498+ ) ?
499+ . build ( ) ?;
500+
501+ let optimized =
502+ optimize_with_rule ( plan, Arc :: new ( MultiDistinctCountRewrite :: new ( ) ) ) ?;
503+ let s = optimized. display_indent_schema ( ) . to_string ( ) ;
504+ assert ! (
505+ s. matches( "Inner Join" ) . count( ) >= 2 ,
506+ "expected two joins for three branches, got:\n {s}"
507+ ) ;
508+ assert ! (
509+ s. contains( "SubqueryAlias: mdc_base" ) ,
510+ "expected base aggregate, got:\n {s}"
511+ ) ;
512+ Ok ( ( ) )
513+ }
514+
515+ #[ test]
516+ fn rewrites_interleaved_non_distinct_between_distincts ( ) -> Result < ( ) > {
517+ let table_scan = test_table_scan ( ) ?;
518+ let plan = LogicalPlanBuilder :: from ( table_scan)
519+ . aggregate (
520+ vec ! [ col( "a" ) ] ,
521+ vec ! [
522+ count_distinct( col( "b" ) ) ,
523+ count( col( "a" ) ) ,
524+ count_distinct( col( "c" ) ) ,
525+ ] ,
526+ ) ?
527+ . build ( ) ?;
528+
529+ let optimized =
530+ optimize_with_rule ( plan, Arc :: new ( MultiDistinctCountRewrite :: new ( ) ) ) ?;
531+ let s = optimized. display_indent_schema ( ) . to_string ( ) ;
532+ assert ! ( s. contains( "Inner Join" ) , "expected join rewrite, got:\n {s}" ) ;
533+ assert ! (
534+ s. contains( "SubqueryAlias: mdc_base" ) ,
535+ "expected base for middle count(a), got:\n {s}"
536+ ) ;
537+ Ok ( ( ) )
538+ }
539+
540+ #[ test]
541+ fn rewrites_count_distinct_on_cast_exprs ( ) -> Result < ( ) > {
542+ let table_scan = test_table_scan ( ) ?;
543+ let plan = LogicalPlanBuilder :: from ( table_scan)
544+ . aggregate (
545+ vec ! [ col( "a" ) ] ,
546+ vec ! [
547+ count_distinct( cast( col( "b" ) , DataType :: Int64 ) ) ,
548+ count_distinct( cast( col( "c" ) , DataType :: Int64 ) ) ,
549+ ] ,
550+ ) ?
551+ . build ( ) ?;
552+
553+ let optimized =
554+ optimize_with_rule ( plan, Arc :: new ( MultiDistinctCountRewrite :: new ( ) ) ) ?;
555+ let s = optimized. display_indent_schema ( ) . to_string ( ) ;
556+ assert ! ( s. contains( "Inner Join" ) , "expected join rewrite, got:\n {s}" ) ;
557+ assert ! (
558+ s. contains( "Filter: CAST(test.b AS Int64) IS NOT NULL" ) ,
559+ "expected null filter on cast(b), got:\n {s}"
560+ ) ;
561+ assert ! (
562+ s. contains( "Filter: CAST(test.c AS Int64) IS NOT NULL" ) ,
563+ "expected null filter on cast(c), got:\n {s}"
564+ ) ;
565+ Ok ( ( ) )
566+ }
567+
568+ #[ test]
569+ fn does_not_rewrite_grouping_sets_multi_distinct ( ) -> Result < ( ) > {
570+ let table_scan = test_table_scan ( ) ?;
571+ let group_expr = vec ! [ Expr :: GroupingSet ( GroupingSet :: GroupingSets ( vec![ vec![
572+ col( "a" ) ,
573+ ] ] ) ) ] ;
574+ let aggr_expr = vec ! [ count_distinct( col( "b" ) ) , count_distinct( col( "c" ) ) ] ;
575+ let plan = LogicalPlan :: Aggregate ( Aggregate :: try_new (
576+ Arc :: new ( table_scan) ,
577+ group_expr,
578+ aggr_expr,
579+ ) ?) ;
580+ let before = plan. display_indent_schema ( ) . to_string ( ) ;
581+ let optimized =
582+ optimize_with_rule ( plan, Arc :: new ( MultiDistinctCountRewrite :: new ( ) ) ) ?;
583+ let after = optimized. display_indent_schema ( ) . to_string ( ) ;
584+ assert_eq ! ( before, after) ;
585+ Ok ( ( ) )
586+ }
587+
420588 #[ test]
421589 fn does_not_rewrite_mixed_agg ( ) -> Result < ( ) > {
422590 let table_scan = test_table_scan ( ) ?;
0 commit comments