Skip to content

Commit 92d660d

Browse files
committed
functions: Add dict support for get field
1 parent 2b7d4f9 commit 92d660d

1 file changed

Lines changed: 224 additions & 0 deletions

File tree

datafusion/functions/src/core/getfield.rs

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,53 @@ fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<Column
199199
let string_value = name.try_as_str().flatten().map(|s| s.to_string());
200200

201201
match (array.data_type(), name, string_value) {
202+
// Dictionary-encoded struct: extract the field from the dictionary's
203+
// values (the deduplicated struct array) and rebuild a dictionary with
204+
// the same keys. This preserves dictionary encoding without expanding.
205+
(DataType::Dictionary(key_type, value_type), _, Some(field_name))
206+
if matches!(value_type.as_ref(), DataType::Struct(_)) =>
207+
{
208+
// Downcast to DictionaryArray to access keys and values without
209+
// materializing the dictionary.
210+
macro_rules! extract_dict_field {
211+
($key_ty:ty) => {{
212+
let dict = array
213+
.as_any()
214+
.downcast_ref::<arrow::array::DictionaryArray<$key_ty>>()
215+
.ok_or_else(|| {
216+
datafusion_common::DataFusionError::Internal(format!(
217+
"Failed to downcast dictionary with key type {}",
218+
key_type
219+
))
220+
})?;
221+
let values_struct = as_struct_array(dict.values())?;
222+
let field_col =
223+
values_struct.column_by_name(&field_name).ok_or_else(|| {
224+
datafusion_common::DataFusionError::Execution(format!(
225+
"Field {field_name} not found in dictionary struct"
226+
))
227+
})?;
228+
// Rebuild dictionary: same keys, extracted field as values.
229+
let new_dict = arrow::array::DictionaryArray::<$key_ty>::try_new(
230+
dict.keys().clone(),
231+
Arc::clone(field_col),
232+
)?;
233+
Ok(ColumnarValue::Array(Arc::new(new_dict)))
234+
}};
235+
}
236+
237+
match key_type.as_ref() {
238+
DataType::Int8 => extract_dict_field!(arrow::datatypes::Int8Type),
239+
DataType::Int16 => extract_dict_field!(arrow::datatypes::Int16Type),
240+
DataType::Int32 => extract_dict_field!(arrow::datatypes::Int32Type),
241+
DataType::Int64 => extract_dict_field!(arrow::datatypes::Int64Type),
242+
DataType::UInt8 => extract_dict_field!(arrow::datatypes::UInt8Type),
243+
DataType::UInt16 => extract_dict_field!(arrow::datatypes::UInt16Type),
244+
DataType::UInt32 => extract_dict_field!(arrow::datatypes::UInt32Type),
245+
DataType::UInt64 => extract_dict_field!(arrow::datatypes::UInt64Type),
246+
other => exec_err!("Unsupported dictionary key type: {other}"),
247+
}
248+
}
202249
(DataType::Map(_, _), ScalarValue::List(arr), _) => {
203250
let key_array: Arc<dyn Array> = arr;
204251
process_map_array(&array, key_array)
@@ -338,6 +385,44 @@ impl ScalarUDFImpl for GetFieldFunc {
338385
}
339386
}
340387
}
388+
// Dictionary-encoded struct: resolve the child field from
389+
// the underlying struct, then wrap the result back in the
390+
// same Dictionary type so the promised type matches execution.
391+
DataType::Dictionary(key_type, value_type)
392+
if matches!(value_type.as_ref(), DataType::Struct(_)) =>
393+
{
394+
let DataType::Struct(fields) = value_type.as_ref() else {
395+
unreachable!()
396+
};
397+
let field_name = sv
398+
.as_ref()
399+
.and_then(|sv| {
400+
sv.try_as_str().flatten().filter(|s| !s.is_empty())
401+
})
402+
.ok_or_else(|| {
403+
datafusion_common::DataFusionError::Execution(
404+
"Field name must be a non-empty string".to_string(),
405+
)
406+
})?;
407+
408+
let child_field = fields
409+
.iter()
410+
.find(|f| f.name() == field_name)
411+
.ok_or_else(|| {
412+
plan_datafusion_err!("Field {field_name} not found in struct")
413+
})?;
414+
415+
let nullable = current_field.is_nullable() || child_field.is_nullable();
416+
let dict_type = DataType::Dictionary(
417+
key_type.clone(),
418+
Box::new(child_field.data_type().clone()),
419+
);
420+
current_field = Arc::new(Field::new(
421+
child_field.name(),
422+
dict_type,
423+
nullable,
424+
));
425+
}
341426
DataType::Struct(fields) => {
342427
let field_name = sv
343428
.as_ref()
@@ -569,6 +654,145 @@ mod tests {
569654
Ok(())
570655
}
571656

657+
#[test]
658+
fn test_get_field_dict_encoded_struct() -> Result<()> {
659+
use arrow::array::{DictionaryArray, StringArray, UInt32Array};
660+
use arrow::datatypes::UInt32Type;
661+
662+
let names = Arc::new(StringArray::from(vec!["main", "foo", "bar"])) as ArrayRef;
663+
let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
664+
665+
let struct_fields: Fields = vec![
666+
Field::new("name", DataType::Utf8, false),
667+
Field::new("id", DataType::Int32, false),
668+
]
669+
.into();
670+
671+
let values_struct =
672+
Arc::new(StructArray::new(struct_fields, vec![names, ids], None))
673+
as ArrayRef;
674+
675+
let keys = UInt32Array::from(vec![0u32, 1, 2, 0, 1]);
676+
let dict = DictionaryArray::<UInt32Type>::try_new(keys, values_struct)?;
677+
678+
let base = ColumnarValue::Array(Arc::new(dict));
679+
let key = ScalarValue::Utf8(Some("name".to_string()));
680+
681+
let result = extract_single_field(base, key)?;
682+
let result_array = result.into_array(5)?;
683+
684+
assert!(
685+
matches!(result_array.data_type(), DataType::Dictionary(_, _)),
686+
"expected dictionary output, got {:?}",
687+
result_array.data_type()
688+
);
689+
690+
let result_dict = result_array
691+
.as_any()
692+
.downcast_ref::<DictionaryArray<UInt32Type>>()
693+
.unwrap();
694+
assert_eq!(result_dict.values().len(), 3);
695+
assert_eq!(result_dict.len(), 5);
696+
697+
let resolved = arrow::compute::cast(&result_array, &DataType::Utf8)?;
698+
let string_arr = resolved
699+
.as_any()
700+
.downcast_ref::<StringArray>()
701+
.unwrap();
702+
assert_eq!(string_arr.value(0), "main");
703+
assert_eq!(string_arr.value(1), "foo");
704+
assert_eq!(string_arr.value(2), "bar");
705+
assert_eq!(string_arr.value(3), "main");
706+
assert_eq!(string_arr.value(4), "foo");
707+
708+
Ok(())
709+
}
710+
711+
#[test]
712+
fn test_get_field_nested_dict_struct() -> Result<()> {
713+
use arrow::array::{DictionaryArray, StringArray, UInt32Array};
714+
use arrow::datatypes::UInt32Type;
715+
716+
let func_names =
717+
Arc::new(StringArray::from(vec!["main", "foo"])) as ArrayRef;
718+
let func_files =
719+
Arc::new(StringArray::from(vec!["main.c", "foo.c"])) as ArrayRef;
720+
let func_fields: Fields = vec![
721+
Field::new("name", DataType::Utf8, false),
722+
Field::new("file", DataType::Utf8, false),
723+
]
724+
.into();
725+
let func_struct = Arc::new(StructArray::new(
726+
func_fields.clone(),
727+
vec![func_names, func_files],
728+
None,
729+
)) as ArrayRef;
730+
let func_dict = Arc::new(DictionaryArray::<UInt32Type>::try_new(
731+
UInt32Array::from(vec![0u32, 1, 0]),
732+
func_struct,
733+
)?) as ArrayRef;
734+
735+
let line_nums = Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef;
736+
let line_fields: Fields = vec![
737+
Field::new("num", DataType::Int32, false),
738+
Field::new(
739+
"function",
740+
DataType::Dictionary(
741+
Box::new(DataType::UInt32),
742+
Box::new(DataType::Struct(func_fields)),
743+
),
744+
false,
745+
),
746+
]
747+
.into();
748+
let line_struct =
749+
StructArray::new(line_fields, vec![line_nums, func_dict], None);
750+
751+
let base = ColumnarValue::Array(Arc::new(line_struct));
752+
753+
let func_result = extract_single_field(
754+
base,
755+
ScalarValue::Utf8(Some("function".to_string())),
756+
)?;
757+
758+
let func_array = func_result.into_array(3)?;
759+
assert!(
760+
matches!(func_array.data_type(), DataType::Dictionary(_, _)),
761+
"expected dictionary for function, got {:?}",
762+
func_array.data_type()
763+
);
764+
765+
let name_result = extract_single_field(
766+
ColumnarValue::Array(func_array),
767+
ScalarValue::Utf8(Some("name".to_string())),
768+
)?;
769+
let name_array = name_result.into_array(3)?;
770+
771+
assert!(
772+
matches!(name_array.data_type(), DataType::Dictionary(_, _)),
773+
"expected dictionary for name, got {:?}",
774+
name_array.data_type()
775+
);
776+
777+
let name_dict = name_array
778+
.as_any()
779+
.downcast_ref::<DictionaryArray<UInt32Type>>()
780+
.unwrap();
781+
assert_eq!(name_dict.values().len(), 2);
782+
assert_eq!(name_dict.len(), 3);
783+
784+
let resolved = arrow::compute::cast(&name_array, &DataType::Utf8)?;
785+
let strings = resolved
786+
.as_any()
787+
.downcast_ref::<StringArray>()
788+
.unwrap();
789+
assert_eq!(strings.value(0), "main");
790+
assert_eq!(strings.value(1), "foo");
791+
assert_eq!(strings.value(2), "main");
792+
793+
Ok(())
794+
}
795+
572796
#[test]
573797
fn test_placement_literal_key() {
574798
let func = GetFieldFunc::new();

0 commit comments

Comments
 (0)