1616// under the License.
1717
1818use arrow:: array:: {
19- Array , ArrayRef , AsArray , BooleanArray , Int64Array , ListBuilder , PrimitiveBuilder ,
19+ ArrayRef , AsArray , BooleanArray , Int64Array , ListArray , PrimitiveArray ,
2020} ;
21- use arrow:: datatypes:: ArrowPrimitiveType ;
21+ use arrow:: buffer:: OffsetBuffer ;
22+ use arrow:: datatypes:: { ArrowPrimitiveType , Field } ;
2223use datafusion_common:: HashSet ;
2324use datafusion_common:: hash_utils:: RandomState ;
2425use datafusion_expr_common:: groups_accumulator:: { EmitTo , GroupsAccumulator } ;
2526use std:: hash:: Hash ;
2627use std:: mem:: size_of;
2728use std:: sync:: Arc ;
2829
30+ use crate :: aggregate:: groups_accumulator:: accumulate:: accumulate;
31+
2932pub struct PrimitiveDistinctCountGroupsAccumulator < T : ArrowPrimitiveType >
3033where
3134 T :: Native : Eq + Hash ,
3235{
33- /// Count distinct per group.
34- values : Vec < HashSet < T :: Native , RandomState > > ,
36+ seen : HashSet < ( usize , T :: Native ) , RandomState > ,
37+ num_groups : usize ,
3538}
3639
37- impl < T : ArrowPrimitiveType > Default for PrimitiveDistinctCountGroupsAccumulator < T >
40+ impl < T : ArrowPrimitiveType > PrimitiveDistinctCountGroupsAccumulator < T >
3841where
3942 T :: Native : Eq + Hash ,
4043{
41- fn default ( ) -> Self {
42- Self :: new ( )
44+ pub fn new ( ) -> Self {
45+ Self {
46+ seen : HashSet :: default ( ) ,
47+ num_groups : 0 ,
48+ }
49+ }
50+
51+ fn emit_to_values ( & mut self , emit_to : EmitTo ) -> Vec < Vec < T :: Native > > {
52+ let num_emitted = match emit_to {
53+ EmitTo :: All => self . num_groups ,
54+ EmitTo :: First ( n) => n,
55+ } ;
56+
57+ let mut group_values: Vec < Vec < T :: Native > > = vec ! [ Vec :: new( ) ; num_emitted] ;
58+ let mut remaining = HashSet :: default ( ) ;
59+
60+ for ( group_idx, value) in self . seen . drain ( ) {
61+ if group_idx < num_emitted {
62+ group_values[ group_idx] . push ( value) ;
63+ } else {
64+ remaining. insert ( ( group_idx - num_emitted, value) ) ;
65+ }
66+ }
67+
68+ self . seen = remaining;
69+ match emit_to {
70+ EmitTo :: All => self . num_groups = 0 ,
71+ EmitTo :: First ( n) => self . num_groups = self . num_groups . saturating_sub ( n) ,
72+ }
73+
74+ group_values
4375 }
4476}
4577
46- impl < T : ArrowPrimitiveType > PrimitiveDistinctCountGroupsAccumulator < T >
78+ impl < T : ArrowPrimitiveType > Default for PrimitiveDistinctCountGroupsAccumulator < T >
4779where
4880 T :: Native : Eq + Hash ,
4981{
50- pub fn new ( ) -> Self {
51- Self { values : Vec :: new ( ) }
82+ fn default ( ) -> Self {
83+ Self :: new ( )
5284 }
5385}
5486
@@ -64,47 +96,40 @@ where
6496 opt_filter : Option < & BooleanArray > ,
6597 total_num_groups : usize ,
6698 ) -> datafusion_common:: Result < ( ) > {
67- self . values . resize_with ( total_num_groups, HashSet :: default) ;
68- debug_assert_eq ! ( values. len( ) , 1 , "multiple arguments are not supported" ) ;
69-
99+ debug_assert_eq ! ( values. len( ) , 1 ) ;
100+ self . num_groups = self . num_groups . max ( total_num_groups) ;
70101 let arr = values[ 0 ] . as_primitive :: < T > ( ) ;
71-
72- for ( idx, group_idx) in group_indices. iter ( ) . enumerate ( ) {
73- if let Some ( filter) = opt_filter
74- && !filter. value ( idx)
75- {
76- continue ;
77- }
78- if arr. is_valid ( idx) {
79- let value = arr. value ( idx) ;
80- self . values [ * group_idx] . insert ( value) ;
81- }
82- }
83-
102+ accumulate ( group_indices, arr, opt_filter, |group_idx, value| {
103+ self . seen . insert ( ( group_idx, value) ) ;
104+ } ) ;
84105 Ok ( ( ) )
85106 }
86107
87108 fn evaluate ( & mut self , emit_to : EmitTo ) -> datafusion_common:: Result < ArrayRef > {
88- let counts: Vec < i64 > = emit_to
89- . take_needed ( & mut self . values )
90- . iter ( )
91- . map ( |groups| groups. len ( ) as i64 )
92- . collect ( ) ;
93-
109+ let group_values = self . emit_to_values ( emit_to) ;
110+ let counts: Vec < i64 > = group_values. iter ( ) . map ( |v| v. len ( ) as i64 ) . collect ( ) ;
94111 Ok ( Arc :: new ( Int64Array :: from ( counts) ) )
95112 }
96113
97114 fn state ( & mut self , emit_to : EmitTo ) -> datafusion_common:: Result < Vec < ArrayRef > > {
98- let hash_sets = emit_to. take_needed ( & mut self . values ) ;
99- let mut builder = ListBuilder :: new ( PrimitiveBuilder :: < T > :: new ( ) ) ;
115+ let group_values = self . emit_to_values ( emit_to) ;
100116
101- for set in hash_sets {
102- for value in set {
103- builder . values ( ) . append_value ( value ) ;
104- }
105- builder . append ( true ) ;
117+ let mut offsets = vec ! [ 0i32 ] ;
118+ let mut all_values = Vec :: new ( ) ;
119+ for values in & group_values {
120+ all_values . extend ( values . iter ( ) . copied ( ) ) ;
121+ offsets . push ( all_values . len ( ) as i32 ) ;
106122 }
107- Ok ( vec ! [ Arc :: new( builder. finish( ) ) ] )
123+
124+ let values_array = Arc :: new ( PrimitiveArray :: < T > :: from_iter_values ( all_values) ) ;
125+ let list_array = ListArray :: new (
126+ Arc :: new ( Field :: new_list_field ( T :: DATA_TYPE , true ) ) ,
127+ OffsetBuffer :: new ( offsets. into ( ) ) ,
128+ values_array,
129+ None ,
130+ ) ;
131+
132+ Ok ( vec ! [ Arc :: new( list_array) ] )
108133 }
109134
110135 fn merge_batch (
@@ -114,26 +139,23 @@ where
114139 _opt_filter : Option < & BooleanArray > ,
115140 total_num_groups : usize ,
116141 ) -> datafusion_common:: Result < ( ) > {
117- self . values . resize_with ( total_num_groups, HashSet :: default) ;
142+ debug_assert_eq ! ( values. len( ) , 1 ) ;
143+ self . num_groups = self . num_groups . max ( total_num_groups) ;
118144 let list_array = values[ 0 ] . as_list :: < i32 > ( ) ;
119145
120146 for ( row_idx, group_idx) in group_indices. iter ( ) . enumerate ( ) {
121147 let inner = list_array. value ( row_idx) ;
122- let inner_set = inner. as_primitive :: < T > ( ) ;
123- for i in 0 ..inner . len ( ) {
124- self . values [ * group_idx ] . insert ( inner_set . value ( i ) ) ;
148+ let inner_arr = inner. as_primitive :: < T > ( ) ;
149+ for value in inner_arr . values ( ) . iter ( ) {
150+ self . seen . insert ( ( * group_idx , * value ) ) ;
125151 }
126152 }
153+
127154 Ok ( ( ) )
128155 }
129156
130157 fn size ( & self ) -> usize {
131158 size_of :: < Self > ( )
132- + self . values . capacity ( ) * size_of :: < HashSet < T :: Native , RandomState > > ( )
133- + self
134- . values
135- . iter ( )
136- . map ( |s| s. capacity ( ) * size_of :: < T :: Native > ( ) )
137- . sum :: < usize > ( )
159+ + self . seen . capacity ( ) * ( size_of :: < ( usize , T :: Native ) > ( ) + size_of :: < u64 > ( ) )
138160 }
139161}
0 commit comments