Skip to content

Commit 3a58334

Browse files
committed
functions: Add dict support for get field
1 parent 4389f14 commit 3a58334

1 file changed

Lines changed: 210 additions & 0 deletions

File tree

datafusion/functions/src/core/getfield.rs

Lines changed: 210 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,53 @@ fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<Column
198198
let string_value = name.try_as_str().flatten().map(|s| s.to_string());
199199

200200
match (array.data_type(), name, string_value) {
201+
// Dictionary-encoded struct: extract the field from the dictionary's
202+
// values (the deduplicated struct array) and rebuild a dictionary with
203+
// the same keys. This preserves dictionary encoding without expanding.
204+
(DataType::Dictionary(key_type, value_type), _, Some(field_name))
205+
if matches!(value_type.as_ref(), DataType::Struct(_)) =>
206+
{
207+
// Downcast to DictionaryArray to access keys and values without
208+
// materializing the dictionary.
209+
// Local helper macro, instantiated once per supported dictionary key type.
// `$key_ty` is the Arrow key type used to downcast the input and to build the
// output dictionary. The body reads `array`, `key_type` and `field_name` from
// the enclosing scope and evaluates to `Ok(ColumnarValue::Array(..))` on
// success; every failure path exits the enclosing function early through `?`.
macro_rules! extract_dict_field {
210+
($key_ty:ty) => {{
211+
// View the input as a dictionary keyed by `$key_ty`; a failed downcast is
// reported as an internal error that names the declared key type.
let dict = array
212+
.as_any()
213+
.downcast_ref::<arrow::array::DictionaryArray<$key_ty>>()
214+
.ok_or_else(|| {
215+
datafusion_common::DataFusionError::Internal(format!(
216+
"Failed to downcast dictionary with key type {}",
217+
key_type
218+
))
219+
})?;
220+
// The dictionary values are interpreted as a struct array; the conversion
// error, if any, is propagated unchanged.
let values_struct = as_struct_array(dict.values())?;
221+
// Look the requested child column up by name; a missing field becomes an
// execution error.
let field_col =
222+
values_struct.column_by_name(&field_name).ok_or_else(|| {
223+
datafusion_common::DataFusionError::Execution(format!(
224+
"Field {field_name} not found in dictionary struct"
225+
))
226+
})?;
227+
// Rebuild dictionary: same keys, extracted field as values.
// NOTE(review): only the child column is carried over as the new values;
// a validity bitmap on the struct values themselves is not consulted
// here — confirm whether it needs to be merged into the result.
228+
let new_dict = arrow::array::DictionaryArray::<$key_ty>::try_new(
229+
dict.keys().clone(),
230+
Arc::clone(field_col),
231+
)?;
232+
Ok(ColumnarValue::Array(Arc::new(new_dict)))
233+
}};
234+
}
235+
236+
match key_type.as_ref() {
237+
DataType::Int8 => extract_dict_field!(arrow::datatypes::Int8Type),
238+
DataType::Int16 => extract_dict_field!(arrow::datatypes::Int16Type),
239+
DataType::Int32 => extract_dict_field!(arrow::datatypes::Int32Type),
240+
DataType::Int64 => extract_dict_field!(arrow::datatypes::Int64Type),
241+
DataType::UInt8 => extract_dict_field!(arrow::datatypes::UInt8Type),
242+
DataType::UInt16 => extract_dict_field!(arrow::datatypes::UInt16Type),
243+
DataType::UInt32 => extract_dict_field!(arrow::datatypes::UInt32Type),
244+
DataType::UInt64 => extract_dict_field!(arrow::datatypes::UInt64Type),
245+
other => exec_err!("Unsupported dictionary key type: {other}"),
246+
}
247+
}
201248
(DataType::Map(_, _), ScalarValue::List(arr), _) => {
202249
let key_array: Arc<dyn Array> = arr;
203250
process_map_array(&array, key_array)
@@ -333,6 +380,42 @@ impl ScalarUDFImpl for GetFieldFunc {
333380
}
334381
}
335382
}
383+
// Dictionary-encoded struct: resolve the child field from
384+
// the underlying struct, then wrap the result back in the
385+
// same Dictionary type so the promised type matches execution.
386+
DataType::Dictionary(key_type, value_type)
387+
if matches!(value_type.as_ref(), DataType::Struct(_)) =>
388+
{
389+
let DataType::Struct(fields) = value_type.as_ref() else {
390+
unreachable!()
391+
};
392+
let field_name = sv
393+
.as_ref()
394+
.and_then(|sv| {
395+
sv.try_as_str().flatten().filter(|s| !s.is_empty())
396+
})
397+
.ok_or_else(|| {
398+
datafusion_common::DataFusionError::Execution(
399+
"Field name must be a non-empty string".to_string(),
400+
)
401+
})?;
402+
403+
let child_field = fields
404+
.iter()
405+
.find(|f| f.name() == field_name)
406+
.ok_or_else(|| {
407+
plan_datafusion_err!("Field {field_name} not found in struct")
408+
})?;
409+
410+
let nullable =
411+
current_field.is_nullable() || child_field.is_nullable();
412+
let dict_type = DataType::Dictionary(
413+
key_type.clone(),
414+
Box::new(child_field.data_type().clone()),
415+
);
416+
current_field =
417+
Arc::new(Field::new(child_field.name(), dict_type, nullable));
418+
}
336419
DataType::Struct(fields) => {
337420
let field_name = sv
338421
.as_ref()
@@ -560,6 +643,133 @@ mod tests {
560643
Ok(())
561644
}
562645

646+
// Extracting a field from a dictionary whose values are a struct array must
// yield a dictionary again: the keys are kept and the values become the
// requested child column.
#[test]
647+
fn test_get_field_dict_encoded_struct() -> Result<()> {
648+
use arrow::array::{DictionaryArray, StringArray, UInt32Array};
649+
use arrow::datatypes::UInt32Type;
650+
651+
// Child columns for the three distinct struct values held by the dictionary.
let names = Arc::new(StringArray::from(vec!["main", "foo", "bar"])) as ArrayRef;
652+
let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
653+
654+
let struct_fields: Fields = vec![
655+
Field::new("name", DataType::Utf8, false),
656+
Field::new("id", DataType::Int32, false),
657+
]
658+
.into();
659+
660+
let values_struct =
661+
Arc::new(StructArray::new(struct_fields, vec![names, ids], None)) as ArrayRef;
662+
663+
// Five logical rows referencing the three struct values; keys 0 and 1 repeat.
let keys = UInt32Array::from(vec![0u32, 1, 2, 0, 1]);
664+
let dict = DictionaryArray::<UInt32Type>::try_new(keys, values_struct)?;
665+
666+
let base = ColumnarValue::Array(Arc::new(dict));
667+
let key = ScalarValue::Utf8(Some("name".to_string()));
668+
669+
let result = extract_single_field(base, key)?;
670+
let result_array = result.into_array(5)?;
671+
672+
// The output must still be dictionary-typed.
assert!(
673+
matches!(result_array.data_type(), DataType::Dictionary(_, _)),
674+
"expected dictionary output, got {:?}",
675+
result_array.data_type()
676+
);
677+
678+
let result_dict = result_array
679+
.as_any()
680+
.downcast_ref::<DictionaryArray<UInt32Type>>()
681+
.unwrap();
// Values stay at the 3 distinct entries while the array keeps its 5 rows.
682+
assert_eq!(result_dict.values().len(), 3);
683+
assert_eq!(result_dict.len(), 5);
684+
685+
// Cast to plain Utf8 so each row can be checked against the key order
// [0, 1, 2, 0, 1].
let resolved = arrow::compute::cast(&result_array, &DataType::Utf8)?;
686+
let string_arr = resolved.as_any().downcast_ref::<StringArray>().unwrap();
687+
assert_eq!(string_arr.value(0), "main");
688+
assert_eq!(string_arr.value(1), "foo");
689+
assert_eq!(string_arr.value(2), "bar");
690+
assert_eq!(string_arr.value(3), "main");
691+
assert_eq!(string_arr.value(4), "foo");
692+
693+
Ok(())
694+
}
695+
696+
// Two-step extraction: first pull a dictionary-encoded struct column
// ("function") out of a plain struct, then pull "name" out of that dictionary.
// Both intermediate and final results are expected to be dictionary-typed.
#[test]
697+
fn test_get_field_nested_dict_struct() -> Result<()> {
698+
use arrow::array::{DictionaryArray, StringArray, UInt32Array};
699+
use arrow::datatypes::UInt32Type;
700+
701+
// Inner struct {name, file} with two distinct values.
let func_names = Arc::new(StringArray::from(vec!["main", "foo"])) as ArrayRef;
702+
let func_files = Arc::new(StringArray::from(vec!["main.c", "foo.c"])) as ArrayRef;
703+
let func_fields: Fields = vec![
704+
Field::new("name", DataType::Utf8, false),
705+
Field::new("file", DataType::Utf8, false),
706+
]
707+
.into();
708+
let func_struct = Arc::new(StructArray::new(
709+
func_fields.clone(),
710+
vec![func_names, func_files],
711+
None,
712+
)) as ArrayRef;
// Dictionary over the inner struct: three rows with keys [0, 1, 0].
713+
let func_dict = Arc::new(DictionaryArray::<UInt32Type>::try_new(
714+
UInt32Array::from(vec![0u32, 1, 0]),
715+
func_struct,
716+
)?) as ArrayRef;
717+
718+
// Outer struct {num, function}; "function" is declared with the dictionary
// type built above.
let line_nums = Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef;
719+
let line_fields: Fields = vec![
720+
Field::new("num", DataType::Int32, false),
721+
Field::new(
722+
"function",
723+
DataType::Dictionary(
724+
Box::new(DataType::UInt32),
725+
Box::new(DataType::Struct(func_fields)),
726+
),
727+
false,
728+
),
729+
]
730+
.into();
731+
let line_struct = StructArray::new(line_fields, vec![line_nums, func_dict], None);
732+
733+
let base = ColumnarValue::Array(Arc::new(line_struct));
734+
735+
// Step 1: extract "function" from the plain outer struct.
let func_result =
736+
extract_single_field(base, ScalarValue::Utf8(Some("function".to_string())))?;
737+
738+
let func_array = func_result.into_array(3)?;
739+
assert!(
740+
matches!(func_array.data_type(), DataType::Dictionary(_, _)),
741+
"expected dictionary for function, got {:?}",
742+
func_array.data_type()
743+
);
744+
745+
// Step 2: extract "name" from the dictionary-encoded struct obtained above.
let name_result = extract_single_field(
746+
ColumnarValue::Array(func_array),
747+
ScalarValue::Utf8(Some("name".to_string())),
748+
)?;
749+
let name_array = name_result.into_array(3)?;
750+
751+
assert!(
752+
matches!(name_array.data_type(), DataType::Dictionary(_, _)),
753+
"expected dictionary for name, got {:?}",
754+
name_array.data_type()
755+
);
756+
757+
let name_dict = name_array
758+
.as_any()
759+
.downcast_ref::<DictionaryArray<UInt32Type>>()
760+
.unwrap();
// Values stay at the 2 distinct entries while the array keeps its 3 rows.
761+
assert_eq!(name_dict.values().len(), 2);
762+
assert_eq!(name_dict.len(), 3);
763+
764+
// Cast to plain Utf8 and check each row against the key order [0, 1, 0].
let resolved = arrow::compute::cast(&name_array, &DataType::Utf8)?;
765+
let strings = resolved.as_any().downcast_ref::<StringArray>().unwrap();
766+
assert_eq!(strings.value(0), "main");
767+
assert_eq!(strings.value(1), "foo");
768+
assert_eq!(strings.value(2), "main");
769+
770+
Ok(())
771+
}
772+
563773
#[test]
564774
fn test_placement_literal_key() {
565775
let func = GetFieldFunc::new();

0 commit comments

Comments
 (0)