@@ -38,7 +38,7 @@ use crate::{
3838use datafusion_common:: config:: ConfigOptions ;
3939use datafusion_physical_expr:: utils:: collect_columns;
4040use parking_lot:: Mutex ;
41- use std:: collections:: HashSet ;
41+ use std:: collections:: { HashMap , HashSet } ;
4242
4343use arrow:: array:: { ArrayRef , UInt8Array , UInt16Array , UInt32Array , UInt64Array } ;
4444use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
@@ -1937,15 +1937,53 @@ fn evaluate_optional(
19371937 . collect ( )
19381938}
19391939
1940- fn group_id_array ( group : & [ bool ] , batch : & RecordBatch ) -> Result < ArrayRef > {
1940+ fn group_id_array (
1941+ group : & [ bool ] ,
1942+ group_ordinal : usize ,
1943+ batch : & RecordBatch ,
1944+ ) -> Result < ArrayRef > {
19411945 if group. len ( ) > 64 {
19421946 return not_impl_err ! (
19431947 "Grouping sets with more than 64 columns are not supported"
19441948 ) ;
19451949 }
1950+ let width_bits = if group. len ( ) <= 8 {
1951+ 8
1952+ } else if group. len ( ) <= 16 {
1953+ 16
1954+ } else if group. len ( ) <= 32 {
1955+ 32
1956+ } else {
1957+ 64
1958+ } ;
1959+ let extra_bits = width_bits - group. len ( ) ;
1960+ if extra_bits == 0 && group_ordinal > 0 {
1961+ return not_impl_err ! (
1962+ "Duplicate grouping sets with more than {} grouping columns are not supported" ,
1963+ width_bits
1964+ ) ;
1965+ }
1966+ if extra_bits < usize:: BITS as usize {
1967+ let max_group_ordinal = 1usize << extra_bits;
1968+ if group_ordinal >= max_group_ordinal {
1969+ return not_impl_err ! (
1970+ "Duplicate grouping sets exceed the supported grouping id capacity"
1971+ ) ;
1972+ }
1973+ }
19461974 let group_id = group. iter ( ) . fold ( 0u64 , |acc, & is_null| {
19471975 ( acc << 1 ) | if is_null { 1 } else { 0 }
19481976 } ) ;
1977+ let group_id = if group. len ( ) == 64 {
1978+ if group_ordinal > 0 {
1979+ return not_impl_err ! (
1980+ "Duplicate grouping sets with 64 grouping columns are not supported"
1981+ ) ;
1982+ }
1983+ group_id
1984+ } else {
1985+ ( ( group_ordinal as u64 ) << group. len ( ) ) | group_id
1986+ } ;
19491987 let num_rows = batch. num_rows ( ) ;
19501988 if group. len ( ) <= 8 {
19511989 Ok ( Arc :: new ( UInt8Array :: from ( vec ! [ group_id as u8 ; num_rows] ) ) )
@@ -1972,6 +2010,7 @@ pub fn evaluate_group_by(
19722010 group_by : & PhysicalGroupBy ,
19732011 batch : & RecordBatch ,
19742012) -> Result < Vec < Vec < ArrayRef > > > {
2013+ let mut group_ordinals: HashMap < Vec < bool > , usize > = HashMap :: new ( ) ;
19752014 let exprs = evaluate_expressions_to_arrays (
19762015 group_by. expr . iter ( ) . map ( |( expr, _) | expr) ,
19772016 batch,
@@ -1985,6 +2024,10 @@ pub fn evaluate_group_by(
19852024 . groups
19862025 . iter ( )
19872026 . map ( |group| {
2027+ let group_ordinal = group_ordinals. entry ( group. clone ( ) ) . or_insert ( 0 ) ;
2028+ let current_group_ordinal = * group_ordinal;
2029+ * group_ordinal += 1 ;
2030+
19882031 let mut group_values = Vec :: with_capacity ( group_by. num_group_exprs ( ) ) ;
19892032 group_values. extend ( group. iter ( ) . enumerate ( ) . map ( |( idx, is_null) | {
19902033 if * is_null {
@@ -1994,7 +2037,7 @@ pub fn evaluate_group_by(
19942037 }
19952038 } ) ) ;
19962039 if !group_by. is_single ( ) {
1997- group_values. push ( group_id_array ( group, batch) ?) ;
2040+ group_values. push ( group_id_array ( group, current_group_ordinal , batch) ?) ;
19982041 }
19992042 Ok ( group_values)
20002043 } )
@@ -2018,8 +2061,8 @@ mod tests {
20182061 use crate :: test:: exec:: { BlockingExec , assert_strong_count_converges_to_zero} ;
20192062
20202063 use arrow:: array:: {
2021- DictionaryArray , Float32Array , Float64Array , Int32Array , Int64Array , StructArray ,
2022- UInt32Array , UInt64Array ,
2064+ DictionaryArray , Float32Array , Float64Array , Int32Array , Int64Array , StringArray ,
2065+ StructArray , UInt32Array , UInt64Array ,
20232066 } ;
20242067 use arrow:: compute:: { SortOptions , concat_batches} ;
20252068 use arrow:: datatypes:: { DataType , Int32Type } ;
@@ -3478,6 +3521,71 @@ mod tests {
34783521 Ok ( ( ) )
34793522 }
34803523
3524+ #[ tokio:: test]
3525+ async fn grouping_sets_preserve_duplicate_groups ( ) -> Result < ( ) > {
3526+ let schema = Arc :: new ( Schema :: new ( vec ! [
3527+ Field :: new( "deptno" , DataType :: Int32 , false ) ,
3528+ Field :: new( "job" , DataType :: Utf8 , true ) ,
3529+ Field :: new( "sal" , DataType :: Float64 , true ) ,
3530+ Field :: new( "comm" , DataType :: Float64 , true ) ,
3531+ ] ) ) ;
3532+
3533+ let input = TestMemoryExec :: try_new_exec (
3534+ & [ vec ! [ RecordBatch :: try_new(
3535+ Arc :: clone( & schema) ,
3536+ vec![
3537+ Arc :: new( Int32Array :: from( vec![ 10 , 20 ] ) ) ,
3538+ Arc :: new( StringArray :: from( vec![ Some ( "CLERK" ) , Some ( "MANAGER" ) ] ) ) ,
3539+ Arc :: new( Float64Array :: from( vec![ 1300.0 , 3000.0 ] ) ) ,
3540+ Arc :: new( Float64Array :: from( vec![ None , None ] ) ) ,
3541+ ] ,
3542+ ) ?] ] ,
3543+ Arc :: clone ( & schema) ,
3544+ None ,
3545+ ) ?;
3546+
3547+ let group_by = PhysicalGroupBy :: new (
3548+ vec ! [
3549+ ( col( "deptno" , & schema) ?, "deptno" . to_string( ) ) ,
3550+ ( col( "job" , & schema) ?, "job" . to_string( ) ) ,
3551+ ( col( "sal" , & schema) ?, "sal" . to_string( ) ) ,
3552+ ] ,
3553+ vec ! [
3554+ ( lit( ScalarValue :: Int32 ( None ) ) , "deptno" . to_string( ) ) ,
3555+ ( lit( ScalarValue :: Utf8 ( None ) ) , "job" . to_string( ) ) ,
3556+ ( lit( ScalarValue :: Float64 ( None ) ) , "sal" . to_string( ) ) ,
3557+ ] ,
3558+ vec ! [
3559+ vec![ false , false , true ] ,
3560+ vec![ false , true , false ] ,
3561+ vec![ false , false , true ] ,
3562+ ] ,
3563+ true ,
3564+ ) ;
3565+
3566+ let aggr_exprs: Vec < Arc < AggregateFunctionExpr > > = vec ! [ Arc :: new(
3567+ AggregateExprBuilder :: new( count_udaf( ) , vec![ lit( 1 ) ] )
3568+ . schema( Arc :: clone( & schema) )
3569+ . alias( "COUNT(1)" )
3570+ . build( ) ?,
3571+ ) ] ;
3572+
3573+ let aggregate_exec = Arc :: new ( AggregateExec :: try_new (
3574+ AggregateMode :: Single ,
3575+ group_by,
3576+ aggr_exprs,
3577+ vec ! [ None ] ,
3578+ Arc :: clone ( & input) as Arc < dyn ExecutionPlan > ,
3579+ Arc :: clone ( & schema) ,
3580+ ) ?) ;
3581+
3582+ let output =
3583+ collect ( aggregate_exec. execute ( 0 , Arc :: new ( TaskContext :: default ( ) ) ) ?) . await ?;
3584+ let batch = concat_batches ( & output[ 0 ] . schema ( ) , & output) ?;
3585+ assert_eq ! ( batch. num_rows( ) , 6 ) ;
3586+ Ok ( ( ) )
3587+ }
3588+
34813589 // test for https://github.com/apache/datafusion/issues/13949
34823590 async fn run_test_with_spill_pool_if_necessary (
34833591 pool_size : usize ,
0 commit comments