Skip to content

Commit 3e60f15

Browse files
committed
fixed null handling & added test
1 parent 92ec4b6 commit 3e60f15

1 file changed

Lines changed: 58 additions & 6 deletions

File tree

  • datafusion/physical-plan/src/aggregates/group_values/single_group_by

datafusion/physical-plan/src/aggregates/group_values/single_group_by/dictionary.rs

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,13 @@ impl<K:ArrowDictionaryKeyType + Send> GroupValues for GroupValuesDictionary<K>
104104
// A. if it has, grab the corresponding initial group integer assigned to it
105105
// B. if it has not, its group integer is self.seen_elements.len - 1; then store this mapping
106106
for i in 0..key_array.len() {
107-
if key_array.is_null(i){
108-
// Null case -> skip!
109-
continue;
110-
}
111-
let key = key_array.value(i).as_usize();
112-
let scalar_value = ScalarValue::try_from_array(values, key)?;
107+
let scalar_value = match key_array.is_null(i) {
108+
true => ScalarValue::try_from(&self.value_dt)?,
109+
false => {
110+
let key = key_array.value(i).to_usize().unwrap();
111+
ScalarValue::try_from_array(values, key)?},
112+
113+
};
113114
let group_id = if let Some(group_id) = self.unique_dict_value_mapping.get(&scalar_value) {
114115
*group_id
115116
} else {
@@ -680,6 +681,7 @@ mod group_values_trait_test {
680681

681682
mod data_correctness {
682683
use super::*;
684+
use arrow::array::Int32Array;
683685

684686
pub fn test_group_assignment_order(group_values_trait_obj: &mut dyn GroupValues) {
685687
let dict_array =
@@ -842,5 +844,55 @@ mod group_values_trait_test {
842844
let mut group_values = GroupValuesDictionary::<arrow::datatypes::UInt8Type>::new(&DataType::Utf8);
843845
test_emit_restores_intern_ability(&mut group_values);
844846
}
847+
fn test_null_keys_form_single_group(group_values: &mut dyn GroupValues) -> Result<()> {
848+
// keys: [0, null, 1, null, 0]
849+
// values: ["a", "b"]
850+
// null keys should all map to the same group
851+
let keys = Int32Array::from(vec![Some(0), None, Some(1), None, Some(0)]);
852+
let values = StringArray::from(vec!["a", "b"]);
853+
let dict = Arc::new(DictionaryArray::new(keys, Arc::new(values))) as ArrayRef;
854+
855+
let mut groups = Vec::new();
856+
group_values.intern(&[dict], &mut groups)?;
857+
858+
// should have 3 groups: "a", "b", null
859+
assert_eq!(group_values.len(), 3);
860+
// null rows (index 1 and 3) should map to same group
861+
assert_eq!(groups[1], groups[3]);
862+
// non null rows should map to correct groups
863+
assert_eq!(groups[0], groups[4]); // both "a"
864+
assert_ne!(groups[0], groups[2]); // "a" != "b"
865+
Ok(())
866+
}
867+
#[test]
868+
fn run_test_null_keys_form_single_group() {
869+
let mut group_values = GroupValuesDictionary::<arrow::datatypes::Int32Type>::new(&DataType::Utf8);
870+
test_null_keys_form_single_group(&mut group_values).unwrap();
871+
}
872+
873+
fn test_null_values_in_dictionary_form_single_group(group_values: &mut dyn GroupValues) -> Result<()> {
874+
// keys: [0, 1, 2, 1, 0]
875+
// values: ["a", null, "b"]
876+
// keys pointing to null value should all map to same group
877+
let keys = Int32Array::from(vec![0, 1, 2, 1, 0]);
878+
let values = StringArray::from(vec![Some("a"), None, Some("b")]);
879+
let dict = Arc::new(DictionaryArray::new(keys, Arc::new(values))) as ArrayRef;
880+
881+
let mut groups = Vec::new();
882+
group_values.intern(&[dict], &mut groups)?;
883+
884+
// should have 3 groups: "a", null, "b"
885+
assert_eq!(group_values.len(), 3);
886+
// rows pointing to null value (index 1 and 3) should map to same group
887+
assert_eq!(groups[1], groups[3]);
888+
// non null rows should map correctly
889+
assert_eq!(groups[0], groups[4]); // both "a"
890+
assert_ne!(groups[0], groups[2]); // "a" != "b"
891+
Ok(())
845892
}
893+
#[test]
894+
fn run_test_null_values_in_dictionary_form_single_group() {
895+
let mut group_values = GroupValuesDictionary::<arrow::datatypes::Int32Type>::new(&DataType::Utf8);
896+
test_null_values_in_dictionary_form_single_group(&mut group_values).unwrap();
897+
}}
846898
}

0 commit comments

Comments
 (0)