@@ -59,6 +59,7 @@ use std::{
5959 ops:: BitAnd ,
6060 sync:: Arc ,
6161} ;
62+ use datafusion_functions_aggregate_common:: aggregate:: count_distinct:: PrimitiveDistinctCountGroupsAccumulator ;
6263
6364make_udaf_expr_and_func ! (
6465 Count ,
@@ -336,18 +337,27 @@ impl AggregateUDFImpl for Count {
336337 }
337338
338339 fn groups_accumulator_supported ( & self , args : AccumulatorArgs ) -> bool {
339- // groups accumulator only supports `COUNT(c1)`, not
340- // `COUNT(c1, c2)`, etc
341- if args. is_distinct {
342- return false ;
343- }
344340 args. exprs . len ( ) == 1
345341 }
346342
347343 fn create_groups_accumulator (
348344 & self ,
349- _args : AccumulatorArgs ,
345+ args : AccumulatorArgs ,
350346 ) -> Result < Box < dyn GroupsAccumulator > > {
347+ if ( args. is_distinct ) {
348+ let data_type = args. expr_fields [ 0 ] . data_type ( ) ;
349+ return match data_type{
350+ DataType :: Int8 => Ok ( Box :: new ( PrimitiveDistinctCountGroupsAccumulator :: < Int8Type > :: new ( data_type. clone ( ) ) ) ) ,
351+ DataType :: Int16 => Ok ( Box :: new ( PrimitiveDistinctCountGroupsAccumulator :: < Int16Type > :: new ( data_type. clone ( ) ) ) ) ,
352+ DataType :: Int32 => Ok ( Box :: new ( PrimitiveDistinctCountGroupsAccumulator :: < Int32Type > :: new ( data_type. clone ( ) ) ) ) ,
353+ DataType :: Int64 => Ok ( Box :: new ( PrimitiveDistinctCountGroupsAccumulator :: < Int64Type > :: new ( data_type. clone ( ) ) ) ) ,
354+ DataType :: UInt8 => Ok ( Box :: new ( PrimitiveDistinctCountGroupsAccumulator :: < UInt8Type > :: new ( data_type. clone ( ) ) ) ) ,
355+ DataType :: UInt16 => Ok ( Box :: new ( PrimitiveDistinctCountGroupsAccumulator :: < UInt16Type > :: new ( data_type. clone ( ) ) ) ) ,
356+ DataType :: UInt32 => Ok ( Box :: new ( PrimitiveDistinctCountGroupsAccumulator :: < UInt32Type > :: new ( data_type. clone ( ) ) ) ) ,
357+ DataType :: UInt64 => Ok ( Box :: new ( PrimitiveDistinctCountGroupsAccumulator :: < UInt64Type > :: new ( data_type. clone ( ) ) ) ) ,
358+ _ => not_impl_err ! ( "GroupsAccumulator not supported for COUNT(DISTINCT) with {}" , data_type) ,
359+ }
360+ }
351361 // instantiate specialized accumulator
352362 Ok ( Box :: new ( CountGroupsAccumulator :: new ( ) ) )
353363 }
0 commit comments