@@ -104,12 +104,13 @@ impl<K:ArrowDictionaryKeyType + Send> GroupValues for GroupValuesDictionary<K>
104104 // A. if it has, grab the corresponding initial group integer assigned to it
105105 // B. if it has not, its group integer is self.seen_elements.len - 1; then store this mapping
106106 for i in 0 ..key_array. len ( ) {
107- if key_array. is_null ( i) {
108- // Null case -> skip!
109- continue ;
110- }
111- let key = key_array. value ( i) . as_usize ( ) ;
112- let scalar_value = ScalarValue :: try_from_array ( values, key) ?;
107+ let scalar_value = match key_array. is_null ( i) {
108+ true => ScalarValue :: try_from ( & self . value_dt ) ?,
109+ false => {
110+ let key = key_array. value ( i) . to_usize ( ) . unwrap ( ) ;
111+ ScalarValue :: try_from_array ( values, key) ?} ,
112+
113+ } ;
113114 let group_id = if let Some ( group_id) = self . unique_dict_value_mapping . get ( & scalar_value) {
114115 * group_id
115116 } else {
@@ -680,6 +681,7 @@ mod group_values_trait_test {
680681
681682 mod data_correctness {
682683 use super :: * ;
684+ use arrow:: array:: Int32Array ;
683685
684686 pub fn test_group_assignment_order ( group_values_trait_obj : & mut dyn GroupValues ) {
685687 let dict_array =
@@ -842,5 +844,55 @@ mod group_values_trait_test {
842844 let mut group_values = GroupValuesDictionary :: < arrow:: datatypes:: UInt8Type > :: new ( & DataType :: Utf8 ) ;
843845 test_emit_restores_intern_ability ( & mut group_values) ;
844846 }
847+ fn test_null_keys_form_single_group ( group_values : & mut dyn GroupValues ) -> Result < ( ) > {
848+ // keys: [0, null, 1, null, 0]
849+ // values: ["a", "b"]
850+ // null keys should all map to the same group
851+ let keys = Int32Array :: from ( vec ! [ Some ( 0 ) , None , Some ( 1 ) , None , Some ( 0 ) ] ) ;
852+ let values = StringArray :: from ( vec ! [ "a" , "b" ] ) ;
853+ let dict = Arc :: new ( DictionaryArray :: new ( keys, Arc :: new ( values) ) ) as ArrayRef ;
854+
855+ let mut groups = Vec :: new ( ) ;
856+ group_values. intern ( & [ dict] , & mut groups) ?;
857+
858+ // should have 3 groups: "a", "b", null
859+ assert_eq ! ( group_values. len( ) , 3 ) ;
860+ // null rows (indices 1 and 3) should map to the same group
861+ assert_eq ! ( groups[ 1 ] , groups[ 3 ] ) ;
862+ // non-null rows should map to the correct groups
863+ assert_eq ! ( groups[ 0 ] , groups[ 4 ] ) ; // both "a"
864+ assert_ne ! ( groups[ 0 ] , groups[ 2 ] ) ; // "a" != "b"
865+ Ok ( ( ) )
866+ }
867+ #[ test]
868+ fn run_test_null_keys_form_single_group ( ) {
869+ let mut group_values = GroupValuesDictionary :: < arrow:: datatypes:: Int32Type > :: new ( & DataType :: Utf8 ) ;
870+ test_null_keys_form_single_group ( & mut group_values) . unwrap ( ) ;
871+ }
872+
873+ fn test_null_values_in_dictionary_form_single_group ( group_values : & mut dyn GroupValues ) -> Result < ( ) > {
874+ // keys: [0, 1, 2, 1, 0]
875+ // values: ["a", null, "b"]
876+ // keys pointing to the null value should all map to the same group
877+ let keys = Int32Array :: from ( vec ! [ 0 , 1 , 2 , 1 , 0 ] ) ;
878+ let values = StringArray :: from ( vec ! [ Some ( "a" ) , None , Some ( "b" ) ] ) ;
879+ let dict = Arc :: new ( DictionaryArray :: new ( keys, Arc :: new ( values) ) ) as ArrayRef ;
880+
881+ let mut groups = Vec :: new ( ) ;
882+ group_values. intern ( & [ dict] , & mut groups) ?;
883+
884+ // should have 3 groups: "a", null, "b"
885+ assert_eq ! ( group_values. len( ) , 3 ) ;
886+ // rows pointing to the null value (indices 1 and 3) should map to the same group
887+ assert_eq ! ( groups[ 1 ] , groups[ 3 ] ) ;
888+ // non-null rows should map correctly
889+ assert_eq ! ( groups[ 0 ] , groups[ 4 ] ) ; // both "a"
890+ assert_ne ! ( groups[ 0 ] , groups[ 2 ] ) ; // "a" != "b"
891+ Ok ( ( ) )
845892 }
893+ #[ test]
894+ fn run_test_null_values_in_dictionary_form_single_group ( ) {
895+ let mut group_values = GroupValuesDictionary :: < arrow:: datatypes:: Int32Type > :: new ( & DataType :: Utf8 ) ;
896+ test_null_values_in_dictionary_form_single_group ( & mut group_values) . unwrap ( ) ;
897+ } }
846898}
0 commit comments