diff --git a/datafusion/spark/src/function/array/slice.rs b/datafusion/spark/src/function/array/slice.rs
index d3884d8f3b902..bcd10a1bf7d79 100644
--- a/datafusion/spark/src/function/array/slice.rs
+++ b/datafusion/spark/src/function/array/slice.rs
@@ -19,7 +19,9 @@ use arrow::array::{Array, ArrayRef, Int64Builder};
 use arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion_common::cast::{as_int64_array, as_list_array};
 use datafusion_common::utils::ListCoercion;
-use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
+use datafusion_common::{
+    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
+};
 use datafusion_expr::{
     ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
     ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -78,17 +80,28 @@ impl ScalarUDFImpl for SparkSlice {
 
     fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
         let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
-        Ok(Arc::new(Field::new(
-            "slice",
-            args.arg_fields[0].data_type().clone(),
-            nullable,
-        )))
+        let data_type = match args.arg_fields[0].data_type() {
+            DataType::Null => {
+                DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
+            }
+            dt => dt.clone(),
+        };
+
+        Ok(Arc::new(Field::new("slice", data_type, nullable)))
     }
 
     fn invoke_with_args(
         &self,
         mut func_args: ScalarFunctionArgs,
     ) -> Result<ColumnarValue> {
+        if func_args.args[0].data_type() == DataType::Null {
+            return Ok(ColumnarValue::Scalar(ScalarValue::new_null_list(
+                DataType::Null,
+                true,
+                1,
+            )));
+        }
+
         let array_len = func_args
             .args
             .iter()
@@ -165,3 +178,63 @@ fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
 
     Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish())))
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::NullArray;
+    use arrow::datatypes::Field;
+    use datafusion_common::ScalarValue;
+    use datafusion_common::cast::as_list_array;
+    use datafusion_expr::ReturnFieldArgs;
+
+    #[test]
+    fn test_spark_slice_function_when_input_is_null() {
+        let slice = SparkSlice::new();
+        let arg_fields: Vec<Arc<Field>> = vec![
+            Arc::new(Field::new("a", DataType::Null, true)),
+            Arc::new(Field::new("s", DataType::Int64, true)),
+            Arc::new(Field::new("l", DataType::Int64, true)),
+        ];
+        let out = slice
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &arg_fields,
+                scalar_arguments: &[],
+            })
+            .unwrap();
+        assert_eq!(
+            out.data_type(),
+            &DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
+        );
+    }
+
+    #[test]
+    fn test_spark_slice_function_when_input_array_is_null() {
+        let input_args = vec![
+            ColumnarValue::Array(Arc::new(NullArray::new(1))),
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
+        ];
+
+        let args = ScalarFunctionArgs {
+            args: input_args,
+            arg_fields: vec![Arc::new(Field::new("item", DataType::Null, true))],
+            number_rows: 1,
+            return_field: Arc::new(Field::new(
+                "slice",
+                DataType::List(Arc::new(Field::new_list_field(DataType::Null, true))),
+                true,
+            )),
+            config_options: Arc::new(Default::default()),
+        };
+        let slice = SparkSlice::new();
+        let result = slice.invoke_with_args(args).unwrap();
+        let arr = result.to_array(1).unwrap();
+        let list = as_list_array(&arr).unwrap();
+        assert_eq!(
+            arr.data_type(),
+            &DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
+        );
+        assert!(list.is_null(0));
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/spark/array/slice.slt b/datafusion/sqllogictest/test_files/spark/array/slice.slt
index 4aba076aba6ba..6dfc1c0c6d0bf 100644
--- a/datafusion/sqllogictest/test_files/spark/array/slice.slt
+++ b/datafusion/sqllogictest/test_files/spark/array/slice.slt
@@ -114,3 +114,26 @@ query ?
 SELECT slice([1, 2, 3, 4], CAST('2' AS INT), 4);
 ----
 [2, 3, 4]
+
+query ?
+SELECT slice(column1, column2, column3)
+FROM VALUES
+(NULL, 1, 2),
+(NULL, 1, -2),
+(NULL, -1, 2),
+(NULL, 0, 2);
+----
+NULL
+NULL
+NULL
+NULL
+
+query ?
+SELECT slice(slice(NULL, 1, 2), 1, 2)
+----
+NULL
+
+query ?
+SELECT slice(slice(make_array(NULL), 1, 2), 1, 2)
+----
+[NULL]