1717
1818use crate :: aggregates:: group_values:: GroupValues ;
1919use arrow:: array:: {
20- Array , ArrayBuilder , ArrayRef , DictionaryArray , Int8Builder , Int16Builder , Int32Builder , Int64Builder , LargeStringBuilder , StringArray , StringBuilder , StringViewBuilder , UInt8Builder , UInt16Builder , UInt32Builder , UInt64Builder
20+ Array , ArrayBuilder , ArrayRef , DictionaryArray , Int8Builder , Int16Builder , Int32Builder , Int64Builder , LargeStringBuilder , Scalar , StringArray , StringBuilder , StringViewBuilder , UInt8Builder , UInt16Builder , UInt32Builder , UInt64Builder
2121} ;
22+ use std:: mem;
2223use arrow:: datatypes:: { ArrowDictionaryKeyType , ArrowNativeType , DataType } ;
2324use datafusion_common:: { Result , ScalarValue } ;
2425use datafusion_expr:: EmitTo ;
@@ -76,8 +77,10 @@ impl<K:ArrowDictionaryKeyType + Send> GroupValuesDictionary<K> {
7677impl < K : ArrowDictionaryKeyType + Send > GroupValues for GroupValuesDictionary < K > {
7778 // not really sure how to return the size of strings and binary values so this is a best effort approach
7879 fn size ( & self ) -> usize {
79- 0
80- }
80+ let arr_size = element_size ( & self . value_dt ) * self . unique_dict_value_mapping . len ( ) ;
81+ let dict_size = self . unique_dict_value_mapping . len ( ) * size_of :: < ( ScalarValue , usize ) > ( ) + 100 /* rough estimate for hashmap overhead */ ; // rough estimate for hashmap overhead
82+ arr_size + dict_size
83+ }
8184 fn len ( & self ) -> usize {
8285 self . unique_dict_value_mapping . len ( )
8386 }
@@ -92,7 +95,6 @@ impl<K:ArrowDictionaryKeyType + Send> GroupValues for GroupValuesDictionary<K>
9295 }
9396 let array = cols[ 0 ] . clone ( ) ;
9497 groups. clear ( ) ; // zero out buffer
95- println ! ( "interning with dictionary array: {:#?}" , array) ;
9698 let dict_array = array. as_any ( ) . downcast_ref :: < DictionaryArray < K > > ( ) . unwrap ( ) ;
9799 // grab the keys and values array
98100 let values = dict_array. values ( ) ;
@@ -123,16 +125,56 @@ impl<K:ArrowDictionaryKeyType + Send> GroupValues for GroupValuesDictionary<K>
123125 Ok ( ( ) )
124126 }
125127 fn emit ( & mut self , emit_to : EmitTo ) -> Result < Vec < ArrayRef > > {
126- Ok ( vec ! [ ] )
128+ let columns: Vec < ScalarValue > = match emit_to {
129+ EmitTo :: All => {
130+ self . unique_dict_value_mapping . clear ( ) ;
131+ mem:: take ( & mut self . seen_elements )
132+ } ,
133+ EmitTo :: First ( n) => {
134+ // drain first n elements, keeping the rest
135+ let first_n = self . seen_elements . drain ( ..n) . collect ( ) ;
136+ // shift all remaining group indices down by n
137+ self . unique_dict_value_mapping . retain ( |_, group_idx| {
138+ match group_idx. checked_sub ( n) {
139+ Some ( new_idx) => {
140+ * group_idx = new_idx;
141+ true
142+ }
143+ // this group was in the first n, remove it
144+ None => false ,
145+ }
146+ } ) ;
147+ first_n
148+ }
149+ } ;
150+
151+ // convert Vec<ScalarValue> into an ArrayRef
152+ let array = ScalarValue :: iter_to_array ( columns. into_iter ( ) ) ?;
153+ Ok ( vec ! [ array] )
154+
155+ }
156+ fn clear_shrink ( & mut self , num_rows : usize ) {
157+ self . seen_elements . clear ( ) ;
158+ self . seen_elements . shrink_to ( num_rows) ;
159+ self . unique_dict_value_mapping . clear ( ) ;
160+ self . unique_dict_value_mapping . shrink_to ( num_rows) ;
161+ }
162+ }
163+ fn element_size ( dt : & DataType ) -> usize {
164+ match dt{
165+ DataType :: Utf8 | DataType :: LargeUtf8 => 20 , // rough estimate for average string size
166+ DataType :: Binary | DataType :: LargeBinary => 20 , // rough estimate for average binary size
167+ DataType :: Boolean => 1 ,
168+ DataType :: Int8 | DataType :: UInt8 => 1 ,
169+ DataType :: Int16 | DataType :: UInt16 => 2 ,
170+ DataType :: Int32 | DataType :: UInt32 => 4 ,
171+ DataType :: Int64 | DataType :: UInt64 => 8 ,
172+ _ => 0 , // default case for unsupported types
127173 }
128- fn clear_shrink ( & mut self , num_rows : usize ) { }
129174}
130175
131176#[ cfg( test) ]
132177mod group_values_trait_test {
133- /*
134- cargo test --package datafusion-physical-plan --lib -- aggregates::group_values::single_group_by::dictionary::group_values_trait_test --nocapture
135- */
136178 use super :: * ;
137179 use arrow:: array:: { DictionaryArray , StringArray , UInt8Array } ;
138180 use std:: sync:: Arc ;
@@ -150,47 +192,39 @@ mod group_values_trait_test {
150192 }
151193 /*
152194 cargo test --package datafusion-physical-plan --lib -- aggregates::group_values::single_group_by::dictionary::group_values_trait_test::test_group_values_dictionary --exact --nocapture --include-ignored
153- */
154- #[ test]
155- fn test_group_values_dictionary ( ) {
156- run_groupvalue_test_suite ( ) . unwrap ( ) ;
157- }
158-
159- fn run_groupvalue_test_suite (
160- ) -> Result < ( ) > {
161- let tests: Vec < ( & str , fn ( & mut dyn GroupValues ) ) > = vec ! [
162- ( "test_single_group_all_same_values" , basic_functionality:: test_single_group_all_same_values) ,
163- ( "test_multiple_groups" , basic_functionality:: test_multiple_groups) ,
195+
196+ fn run_groupvalue_test_suite() -> Result<()> {
197+ let tests: Vec<(&str, fn(&mut dyn GroupValues))> = vec![
198+ ("test_single_group_all_same_values", basic_functionality::test_single_group_all_same_values),
199+ ("test_multiple_groups", basic_functionality::test_multiple_groups),
164200 ("test_all_different_values", basic_functionality::test_all_different_values),
165201 ("test_empty_batch", edge_cases::test_empty_batch),
166202 ("test_single_row", edge_cases::test_single_row),
167203 ("test_repeated_pattern", edge_cases::test_repeated_pattern),
168- /*
169- multi_column::test_multiple_columns_passed,
170- consecutive_batches::test_consecutive_batches_then_emit,
171- consecutive_batches::test_three_consecutive_batches_with_partial_emit,
172- state_management::test_size_grows_after_intern,
173- state_management::test_complex_emit_flow_with_multiple_internS,
174- state_management::test_clear_shrink_resets_state,
175- state_management::test_clear_shrink_with_zero,
176- state_management::test_emit_all_clears_state,
177- state_management::test_emit_first_n,
178- state_management::test_complex_emit_flow_with_multiple_internS,
179- data_correctness::test_group_assignment_order,
180- data_correctness::test_groups_vector_correctness_first_appearance,
181- data_correctness::test_groups_vector_sequential_assignment,
182- data_correctness::test_emit_partial_preserves_state,
183- data_correctness::test_emit_restores_intern_ability,
184- */
204+ ("test_multiple_columns_passed", multi_column::test_multiple_columns_passed),
205+ ("test_consecutive_batches_then_emit", consecutive_batches::test_consecutive_batches_then_emit),
206+ ("test_three_consecutive_batches_with_partial_emit", consecutive_batches::test_three_consecutive_batches_with_partial_emit),
207+ ("test_size_grows_after_intern", state_management::test_size_grows_after_intern),
208+ ("test_complex_emit_flow_with_multiple_internS", state_management::test_complex_emit_flow_with_multiple_internS),
209+ ("test_clear_shrink_resets_state", state_management::test_clear_shrink_resets_state),
210+ ("test_clear_shrink_with_zero", state_management::test_clear_shrink_with_zero),
211+ ("test_emit_all_clears_state", state_management::test_emit_all_clears_state),
212+ ("test_emit_first_n", state_management::test_emit_first_n),
213+ ("test_group_assignment_order", data_correctness::test_group_assignment_order),
214+ ("test_groups_vector_correctness_first_appearance", data_correctness::test_groups_vector_correctness_first_appearance),
215+ ("test_groups_vector_sequential_assignment", data_correctness::test_groups_vector_sequential_assignment),
216+ ("test_emit_partial_preserves_state", data_correctness::test_emit_partial_preserves_state),
217+ ("test_emit_restores_intern_ability", data_correctness::test_emit_restores_intern_ability),
185218 ];
186- for ( name, test_functions ) in tests {
219+ for (name, test_function ) in tests {
187220 let mut group_values = GroupValuesDictionary::<arrow::datatypes::UInt8Type>::new(&DataType::Utf8);
188221 println!("Running test: {name}");
189- test_functions ( & mut group_values) ;
222+ test_function (&mut group_values);
190223 }
191224
192225 Ok(())
193226 }
227+ */
194228
195229 mod basic_functionality {
196230 use super :: * ;
@@ -364,13 +398,11 @@ mod group_values_trait_test {
364398 group_values_trait_obj
365399 . intern ( & [ batch2] , & mut groups_vector2)
366400 . unwrap ( ) ;
367-
368401 assert_eq ! ( group_values_trait_obj. len( ) , 3 ) ;
369402 assert_eq ! ( groups_vector2. len( ) , 3 ) ;
370403
371404 let result = group_values_trait_obj. emit ( EmitTo :: All ) . unwrap ( ) ;
372405 assert_eq ! ( result. len( ) , 1 ) ;
373-
374406 assert ! ( group_values_trait_obj. is_empty( ) ) ;
375407 }
376408
@@ -397,16 +429,28 @@ mod group_values_trait_test {
397429 . unwrap ( ) ;
398430 assert_eq ! ( group_values_trait_obj. len( ) , 3 ) ;
399431
400- let batch3 = create_dict_array ( vec ! [ 2 , 3 ] , vec ! [ "c" , "d" ] ) ;
432+ let batch3 = create_dict_array ( vec ! [ 0 , 1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 0 , 1 , 1 , 0 , 1 , 2 , 1 , 2 ] , vec ! [ "c" , "d" , "e "] ) ;
401433 let mut groups_vector3 = Vec :: new ( ) ;
402434 group_values_trait_obj
403435 . intern ( & [ batch3] , & mut groups_vector3)
404436 . unwrap ( ) ;
405- assert_eq ! ( group_values_trait_obj. len( ) , 4 ) ;
437+ assert_eq ! ( group_values_trait_obj. len( ) , 5 ) ;
406438
407439 let result = group_values_trait_obj. emit ( EmitTo :: All ) . unwrap ( ) ;
408440 assert_eq ! ( result. len( ) , 1 ) ;
409441 assert ! ( group_values_trait_obj. is_empty( ) ) ;
442+ result. iter ( ) . for_each ( |array| {
443+ let string_array = array. as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
444+ let values: Vec < String > = ( 0 ..string_array. len ( ) )
445+ . map ( |i| string_array. value ( i) . to_string ( ) )
446+ . collect ( ) ;
447+ let unexpected_values: Vec < & String > = values. iter ( ) . filter ( |v| * * v != "a" && * * v != "b" && * * v != "c" && * * v != "d" && * * v != "e" ) . collect ( ) ;
448+ assert ! (
449+ unexpected_values. is_empty( ) ,
450+ "Emitted unexpected values: {:#?}" ,
451+ unexpected_values
452+ ) ;
453+ } ) ;
410454 }
411455
412456 #[ test]
@@ -422,7 +466,6 @@ mod group_values_trait_test {
422466 fn test_initial_state_is_empty ( group_values_trait_obj : & dyn GroupValues ) {
423467 assert ! ( group_values_trait_obj. is_empty( ) ) ;
424468 assert_eq ! ( group_values_trait_obj. len( ) , 0 ) ;
425- assert_eq ! ( group_values_trait_obj. size( ) , 0 ) ;
426469 }
427470
428471 #[ test]
@@ -756,7 +799,7 @@ mod group_values_trait_test {
756799
757800 #[ test]
758801 fn run_test_emit_partial_preserves_state ( ) {
759- let mut group_values = GroupValuesDictionary :: < arrow:: datatypes:: Int8Type > :: new ( & DataType :: Utf8 ) ;
802+ let mut group_values = GroupValuesDictionary :: < arrow:: datatypes:: UInt8Type > :: new ( & DataType :: Utf8 ) ;
760803 test_emit_partial_preserves_state ( & mut group_values) ;
761804 }
762805
0 commit comments