Skip to content

Commit 18e4a0c

Browse files
authored
fix: derive custom nullability for spark map_from_entries (#19274)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #19161 - Part of #19144 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? - Spark `map_from_entries` now uses `return_field_from_args` to handle nullability <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? - Added new unit tests to cover the changes - Previous tests pass <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 96ddd55 commit 18e4a0c

1 file changed

Lines changed: 94 additions & 9 deletions

File tree

datafusion/spark/src/function/map/map_from_entries.rs

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,20 @@
1616
// under the License.
1717

1818
use std::any::Any;
19+
use std::sync::Arc;
1920

2021
use crate::function::map::utils::{
21-
get_element_type, get_list_offsets, get_list_values,
22-
map_from_keys_values_offsets_nulls, map_type_from_key_value_types,
22+
get_list_offsets, get_list_values, map_from_keys_values_offsets_nulls,
23+
map_type_from_key_value_types,
2324
};
2425
use arrow::array::{Array, ArrayRef, NullBufferBuilder, StructArray};
2526
use arrow::buffer::NullBuffer;
26-
use arrow::datatypes::DataType;
27+
use arrow::datatypes::{DataType, Field, FieldRef};
2728
use datafusion_common::utils::take_function_args;
28-
use datafusion_common::{exec_err, Result};
29-
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
29+
use datafusion_common::{exec_err, internal_err, Result};
30+
use datafusion_expr::{
31+
ColumnarValue, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility,
32+
};
3033
use datafusion_functions::utils::make_scalar_function;
3134

3235
/// Spark-compatible `map_from_entries` expression
@@ -63,9 +66,28 @@ impl ScalarUDFImpl for MapFromEntries {
6366
&self.signature
6467
}
6568

66-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
67-
let [entries_type] = take_function_args("map_from_entries", arg_types)?;
68-
let entries_element_type = get_element_type(entries_type)?;
69+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
70+
internal_err!("return_field_from_args should be used instead")
71+
}
72+
73+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
74+
let [entries_field] = args.arg_fields else {
75+
return exec_err!("map_from_entries: expected one argument");
76+
};
77+
78+
let (entries_element_field, entries_element_type) =
79+
match entries_field.data_type() {
80+
DataType::List(field)
81+
| DataType::LargeList(field)
82+
| DataType::FixedSizeList(field, _) => {
83+
Ok((field.as_ref(), field.data_type()))
84+
}
85+
wrong_type => exec_err!(
86+
"map_from_entries: expected array<struct<key, value>>, got {:?}",
87+
wrong_type
88+
),
89+
}?;
90+
6991
let (keys_type, values_type) = match entries_element_type {
7092
DataType::Struct(fields) if fields.len() == 2 => {
7193
Ok((fields[0].data_type(), fields[1].data_type()))
@@ -75,7 +97,11 @@ impl ScalarUDFImpl for MapFromEntries {
7597
wrong_type
7698
),
7799
}?;
78-
Ok(map_type_from_key_value_types(keys_type, values_type))
100+
101+
let map_type = map_type_from_key_value_types(keys_type, values_type);
102+
let nullable = entries_field.is_nullable() || entries_element_field.is_nullable();
103+
104+
Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
79105
}
80106

81107
fn invoke_with_args(
@@ -131,3 +157,62 @@ fn map_from_entries_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
131157
res_nulls.as_ref(),
132158
)
133159
}
160+
161+
#[cfg(test)]
162+
mod tests {
163+
use super::*;
164+
use arrow::datatypes::Fields;
165+
use datafusion_expr::ReturnFieldArgs;
166+
167+
fn make_entries_field(array_nullable: bool, element_nullable: bool) -> FieldRef {
168+
let struct_type = DataType::Struct(Fields::from(vec![
169+
Field::new("key", DataType::Int32, false),
170+
Field::new("value", DataType::Utf8, true),
171+
]));
172+
Arc::new(Field::new(
173+
"entries",
174+
DataType::List(Arc::new(Field::new("item", struct_type, element_nullable))),
175+
array_nullable,
176+
))
177+
}
178+
179+
#[test]
180+
fn test_map_from_entries_nullability_matches_input() {
181+
let func = MapFromEntries::new();
182+
let expected_type =
183+
map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8);
184+
185+
// Non-nullable array and elements => non-nullable result
186+
let non_nullable_field = make_entries_field(false, false);
187+
let result = func
188+
.return_field_from_args(ReturnFieldArgs {
189+
arg_fields: &[Arc::clone(&non_nullable_field)],
190+
scalar_arguments: &[None],
191+
})
192+
.expect("should infer field");
193+
assert!(!result.is_nullable());
194+
assert_eq!(result.data_type(), &expected_type);
195+
196+
// Nullable elements should make result nullable even if array is non-nullable
197+
let element_nullable_field = make_entries_field(false, true);
198+
let result = func
199+
.return_field_from_args(ReturnFieldArgs {
200+
arg_fields: &[Arc::clone(&element_nullable_field)],
201+
scalar_arguments: &[None],
202+
})
203+
.expect("should infer field");
204+
assert!(result.is_nullable());
205+
assert_eq!(result.data_type(), &expected_type);
206+
207+
// Nullable array should also yield nullable result
208+
let array_nullable_field = make_entries_field(true, false);
209+
let result = func
210+
.return_field_from_args(ReturnFieldArgs {
211+
arg_fields: &[Arc::clone(&array_nullable_field)],
212+
scalar_arguments: &[None],
213+
})
214+
.expect("should infer field");
215+
assert!(result.is_nullable());
216+
assert_eq!(result.data_type(), &expected_type);
217+
}
218+
}

0 commit comments

Comments
 (0)