Skip to content

Commit 3e2f1f8

Browse files
committed
functions: Add dict support for get field (apache#21115)
1 parent 782b19c commit 3e2f1f8

1 file changed

Lines changed: 217 additions & 4 deletions

File tree

datafusion/functions/src/core/getfield.rs

Lines changed: 217 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,16 +19,20 @@ use std::any::Any;
1919
use std::sync::Arc;
2020

2121
use arrow::array::{
22-
Array, BooleanArray, Capacities, MutableArrayData, Scalar, make_array,
23-
make_comparator,
22+
Array, BooleanArray, Capacities, DictionaryArray, MutableArrayData, Scalar,
23+
make_array, make_comparator,
2424
};
2525
use arrow::compute::SortOptions;
26-
use arrow::datatypes::{DataType, Field, FieldRef};
26+
use arrow::datatypes::{
27+
DataType, Field, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
28+
UInt16Type, UInt32Type, UInt64Type,
29+
};
2730
use arrow_buffer::NullBuffer;
2831

2932
use datafusion_common::cast::{as_map_array, as_struct_array};
3033
use datafusion_common::{
31-
Result, ScalarValue, exec_err, internal_err, plan_datafusion_err,
34+
Result, ScalarValue, exec_datafusion_err, exec_err, internal_datafusion_err,
35+
internal_err, plan_datafusion_err,
3236
};
3337
use datafusion_expr::expr::ScalarFunction;
3438
use datafusion_expr::simplify::ExprSimplifyResult;
@@ -199,6 +203,52 @@ fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<Column
199203
let string_value = name.try_as_str().flatten().map(|s| s.to_string());
200204

201205
match (array.data_type(), name, string_value) {
206+
// Dictionary-encoded struct: extract the field from the dictionary's
207+
// values (the deduplicated struct array) and rebuild a dictionary with
208+
// the same keys. This preserves dictionary encoding without expanding.
209+
(DataType::Dictionary(key_type, value_type), _, Some(field_name))
210+
if matches!(value_type.as_ref(), DataType::Struct(_)) =>
211+
{
212+
// Downcast to DictionaryArray to access keys and values without
213+
// materializing the dictionary.
214+
// Expands, for one concrete dictionary key type `$key_ty`, to an expression
// that pulls `field_name` out of the dictionary's struct values and wraps it
// in a new dictionary built from the original keys. The macro captures
// `array`, `key_type` and `field_name` from the enclosing scope, and each `?`
// inside it returns early from the enclosing function.
macro_rules! extract_dict_field {
215+
($key_ty:ty) => {{
216+
// The caller selects `$key_ty` from the runtime key type; a failed
// downcast is reported through `internal_datafusion_err!`.
let dict = array
217+
.as_any()
218+
.downcast_ref::<DictionaryArray<$key_ty>>()
219+
.ok_or_else(|| {
220+
internal_datafusion_err!(
221+
"Failed to downcast dictionary with key type {key_type}"
222+
)
223+
})?;
224+
// Interpret the dictionary values as a struct array; any error from
// `as_struct_array` is propagated.
let values_struct = as_struct_array(dict.values())?;
225+
// Look the child column up by name; a missing field becomes an
// error built with `exec_datafusion_err!`.
let field_col =
226+
values_struct.column_by_name(&field_name).ok_or_else(|| {
227+
exec_datafusion_err!(
228+
"Field {field_name} not found in dictionary struct"
229+
)
230+
})?;
231+
// Rebuild dictionary: same keys, extracted field as values.
232+
// `try_new` is fallible; its error is propagated with `?`.
let new_dict = DictionaryArray::<$key_ty>::try_new(
233+
dict.keys().clone(),
234+
Arc::clone(field_col),
235+
)?;
236+
Ok(ColumnarValue::Array(Arc::new(new_dict)))
237+
}};
238+
}
239+
240+
match key_type.as_ref() {
241+
DataType::Int8 => extract_dict_field!(Int8Type),
242+
DataType::Int16 => extract_dict_field!(Int16Type),
243+
DataType::Int32 => extract_dict_field!(Int32Type),
244+
DataType::Int64 => extract_dict_field!(Int64Type),
245+
DataType::UInt8 => extract_dict_field!(UInt8Type),
246+
DataType::UInt16 => extract_dict_field!(UInt16Type),
247+
DataType::UInt32 => extract_dict_field!(UInt32Type),
248+
DataType::UInt64 => extract_dict_field!(UInt64Type),
249+
other => exec_err!("Unsupported dictionary key type: {other}"),
250+
}
251+
}
202252
(DataType::Map(_, _), ScalarValue::List(arr), _) => {
203253
let key_array: Arc<dyn Array> = arr;
204254
process_map_array(&array, key_array)
@@ -338,6 +388,42 @@ impl ScalarUDFImpl for GetFieldFunc {
338388
}
339389
}
340390
}
391+
// Dictionary-encoded struct: resolve the child field from
392+
// the underlying struct, then wrap the result back in the
393+
// same Dictionary type so the promised type matches execution.
394+
DataType::Dictionary(key_type, value_type)
395+
if matches!(value_type.as_ref(), DataType::Struct(_)) =>
396+
{
397+
let DataType::Struct(fields) = value_type.as_ref() else {
398+
unreachable!()
399+
};
400+
let field_name = sv
401+
.as_ref()
402+
.and_then(|sv| {
403+
sv.try_as_str().flatten().filter(|s| !s.is_empty())
404+
})
405+
.ok_or_else(|| {
406+
exec_datafusion_err!("Field name must be a non-empty string")
407+
})?;
408+
409+
let child_field = fields
410+
.iter()
411+
.find(|f| f.name() == field_name)
412+
.ok_or_else(|| {
413+
plan_datafusion_err!("Field {field_name} not found in struct")
414+
})?;
415+
416+
let dict_type = DataType::Dictionary(
417+
key_type.clone(),
418+
Box::new(child_field.data_type().clone()),
419+
);
420+
let mut new_field =
421+
child_field.as_ref().clone().with_data_type(dict_type);
422+
if current_field.is_nullable() {
423+
new_field = new_field.with_nullable(true);
424+
}
425+
current_field = Arc::new(new_field);
426+
}
341427
DataType::Struct(fields) => {
342428
let field_name = sv
343429
.as_ref()
@@ -569,6 +655,133 @@ mod tests {
569655
Ok(())
570656
}
571657

658+
/// Checks that `extract_single_field` applied to a dictionary-encoded struct
/// array returns a dictionary array: three dictionary values, five rows, and
/// row values that follow the original keys.
#[test]
659+
fn test_get_field_dict_encoded_struct() -> Result<()> {
660+
use arrow::array::{DictionaryArray, StringArray, UInt32Array};
661+
use arrow::datatypes::UInt32Type;
662+
663+
// Three distinct struct values: ("main", 1), ("foo", 2), ("bar", 3).
let names = Arc::new(StringArray::from(vec!["main", "foo", "bar"])) as ArrayRef;
664+
let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
665+
666+
let struct_fields: Fields = vec![
667+
Field::new("name", DataType::Utf8, false),
668+
Field::new("id", DataType::Int32, false),
669+
]
670+
.into();
671+
672+
let values_struct =
673+
Arc::new(StructArray::new(struct_fields, vec![names, ids], None)) as ArrayRef;
674+
675+
// Five rows referencing the three values; keys 0 and 1 are repeated.
let keys = UInt32Array::from(vec![0u32, 1, 2, 0, 1]);
676+
let dict = DictionaryArray::<UInt32Type>::try_new(keys, values_struct)?;
677+
678+
let base = ColumnarValue::Array(Arc::new(dict));
679+
let key = ScalarValue::Utf8(Some("name".to_string()));
680+
681+
let result = extract_single_field(base, key)?;
682+
let result_array = result.into_array(5)?;
683+
684+
// The output must still be dictionary-encoded.
assert!(
685+
matches!(result_array.data_type(), DataType::Dictionary(_, _)),
686+
"expected dictionary output, got {:?}",
687+
result_array.data_type()
688+
);
689+
690+
// The key type is expected to be unchanged (UInt32), with 3 dictionary
// values backing 5 logical rows.
let result_dict = result_array
691+
.as_any()
692+
.downcast_ref::<DictionaryArray<UInt32Type>>()
693+
.unwrap();
694+
assert_eq!(result_dict.values().len(), 3);
695+
assert_eq!(result_dict.len(), 5);
696+
697+
// Cast to plain Utf8 so the per-row values can be compared against
// keys [0, 1, 2, 0, 1].
let resolved = arrow::compute::cast(&result_array, &DataType::Utf8)?;
698+
let string_arr = resolved.as_any().downcast_ref::<StringArray>().unwrap();
699+
assert_eq!(string_arr.value(0), "main");
700+
assert_eq!(string_arr.value(1), "foo");
701+
assert_eq!(string_arr.value(2), "bar");
702+
assert_eq!(string_arr.value(3), "main");
703+
assert_eq!(string_arr.value(4), "foo");
704+
705+
Ok(())
706+
}
707+
708+
/// Checks two chained `extract_single_field` calls: first `function` is
/// extracted from a plain struct whose `function` child is a
/// dictionary-encoded struct, then `name` is extracted from that dictionary.
/// Both results are expected to be dictionary arrays.
#[test]
709+
fn test_get_field_nested_dict_struct() -> Result<()> {
710+
use arrow::array::{DictionaryArray, StringArray, UInt32Array};
711+
use arrow::datatypes::UInt32Type;
712+
713+
// Inner dictionary: two struct values (name, file) referenced by keys
// [0, 1, 0].
let func_names = Arc::new(StringArray::from(vec!["main", "foo"])) as ArrayRef;
714+
let func_files = Arc::new(StringArray::from(vec!["main.c", "foo.c"])) as ArrayRef;
715+
let func_fields: Fields = vec![
716+
Field::new("name", DataType::Utf8, false),
717+
Field::new("file", DataType::Utf8, false),
718+
]
719+
.into();
720+
let func_struct = Arc::new(StructArray::new(
721+
func_fields.clone(),
722+
vec![func_names, func_files],
723+
None,
724+
)) as ArrayRef;
725+
let func_dict = Arc::new(DictionaryArray::<UInt32Type>::try_new(
726+
UInt32Array::from(vec![0u32, 1, 0]),
727+
func_struct,
728+
)?) as ArrayRef;
729+
730+
// Outer struct: an Int32 `num` column plus the dictionary-encoded
// `function` column built above.
let line_nums = Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef;
731+
let line_fields: Fields = vec![
732+
Field::new("num", DataType::Int32, false),
733+
Field::new(
734+
"function",
735+
DataType::Dictionary(
736+
Box::new(DataType::UInt32),
737+
Box::new(DataType::Struct(func_fields)),
738+
),
739+
false,
740+
),
741+
]
742+
.into();
743+
let line_struct = StructArray::new(line_fields, vec![line_nums, func_dict], None);
744+
745+
let base = ColumnarValue::Array(Arc::new(line_struct));
746+
747+
// Step 1: extracting `function` from the plain struct should yield the
// dictionary column.
let func_result =
748+
extract_single_field(base, ScalarValue::Utf8(Some("function".to_string())))?;
749+
750+
let func_array = func_result.into_array(3)?;
751+
assert!(
752+
matches!(func_array.data_type(), DataType::Dictionary(_, _)),
753+
"expected dictionary for function, got {:?}",
754+
func_array.data_type()
755+
);
756+
757+
// Step 2: extracting `name` from that dictionary should again produce a
// dictionary.
let name_result = extract_single_field(
758+
ColumnarValue::Array(func_array),
759+
ScalarValue::Utf8(Some("name".to_string())),
760+
)?;
761+
let name_array = name_result.into_array(3)?;
762+
763+
assert!(
764+
matches!(name_array.data_type(), DataType::Dictionary(_, _)),
765+
"expected dictionary for name, got {:?}",
766+
name_array.data_type()
767+
);
768+
769+
// Two dictionary values backing three logical rows.
let name_dict = name_array
770+
.as_any()
771+
.downcast_ref::<DictionaryArray<UInt32Type>>()
772+
.unwrap();
773+
assert_eq!(name_dict.values().len(), 2);
774+
assert_eq!(name_dict.len(), 3);
775+
776+
// Cast to Utf8 and check the rows follow keys [0, 1, 0].
let resolved = arrow::compute::cast(&name_array, &DataType::Utf8)?;
777+
let strings = resolved.as_any().downcast_ref::<StringArray>().unwrap();
778+
assert_eq!(strings.value(0), "main");
779+
assert_eq!(strings.value(1), "foo");
780+
assert_eq!(strings.value(2), "main");
781+
782+
Ok(())
783+
}
784+
572785
#[test]
573786
fn test_placement_literal_key() {
574787
let func = GetFieldFunc::new();

0 commit comments

Comments
 (0)