Skip to content

Commit 7c1604e

Browse files
Return List of Nulls for Null input
1 parent 2fc0ed0 commit 7c1604e

2 files changed

Lines changed: 57 additions & 31 deletions

File tree

  • datafusion

datafusion/spark/src/function/array/slice.rs

Lines changed: 52 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{Array, ArrayRef, Int64Builder};
18+
use arrow::array::{Array, ArrayData, ArrayRef, Int64Builder, ListArray};
1919
use arrow::datatypes::{DataType, Field, FieldRef};
2020
use datafusion_common::cast::{as_int64_array, as_list_array};
2121
use datafusion_common::utils::ListCoercion;
22-
use datafusion_common::{
23-
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
24-
};
22+
use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
2523
use datafusion_expr::{
2624
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
2725
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -80,21 +78,26 @@ impl ScalarUDFImpl for SparkSlice {
8078
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
8179
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
8280

83-
Ok(Arc::new(Field::new(
84-
"slice",
85-
args.arg_fields[0].data_type().clone(),
86-
nullable,
87-
)))
81+
let data_type = match args.arg_fields[0].data_type() {
82+
DataType::Null => {
83+
DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
84+
}
85+
dt => dt.clone(),
86+
};
87+
88+
Ok(Arc::new(Field::new("slice", data_type, nullable)))
8889
}
8990

9091
fn invoke_with_args(
9192
&self,
9293
mut func_args: ScalarFunctionArgs,
9394
) -> Result<ColumnarValue> {
94-
if func_args.args[0].data_type() == DataType::Null
95-
&& let Some(result) = check_null_types(&func_args.args[0])
96-
{
97-
return Ok(result);
95+
if func_args.args[0].data_type() == DataType::Null {
96+
let len = match &func_args.args[0] {
97+
ColumnarValue::Array(a) => a.len(),
98+
ColumnarValue::Scalar(_) => func_args.number_rows,
99+
};
100+
return Ok(ColumnarValue::Array(list_null_array(len)));
98101
}
99102

100103
let array_len = func_args
@@ -131,14 +134,9 @@ impl ScalarUDFImpl for SparkSlice {
131134
}
132135
}
133136

134-
fn check_null_types(cv: &ColumnarValue) -> Option<ColumnarValue> {
135-
match cv {
136-
ColumnarValue::Scalar(ScalarValue::Null) => {
137-
Some(ColumnarValue::create_null_array(1))
138-
}
139-
ColumnarValue::Array(_) => Some(cv.clone()),
140-
_ => None,
141-
}
137+
fn list_null_array(len: usize) -> ArrayRef {
138+
let list_type = DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)));
139+
Arc::new(ListArray::from(ArrayData::new_null(&list_type, len)))
142140
}
143141

144142
fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
@@ -188,9 +186,30 @@ fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
188186
mod tests {
189187
use super::*;
190188
use arrow::array::NullArray;
191-
use arrow::datatypes::DataType::List;
192189
use arrow::datatypes::Field;
193190
use datafusion_common::ScalarValue;
191+
use datafusion_common::cast::as_list_array;
192+
use datafusion_expr::ReturnFieldArgs;
193+
194+
#[test]
195+
fn test_spark_slice_function_when_input_is_null() {
196+
let slice = SparkSlice::new();
197+
let arg_fields: Vec<Arc<Field>> = vec![
198+
Arc::new(Field::new("a", DataType::Null, true)),
199+
Arc::new(Field::new("s", DataType::Int64, true)),
200+
Arc::new(Field::new("l", DataType::Int64, true)),
201+
];
202+
let out = slice
203+
.return_field_from_args(ReturnFieldArgs {
204+
arg_fields: &arg_fields,
205+
scalar_arguments: &[],
206+
})
207+
.unwrap();
208+
assert_eq!(
209+
out.data_type(),
210+
&DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
211+
);
212+
}
194213

195214
#[test]
196215
fn test_spark_slice_function_when_input_array_is_null() {
@@ -202,21 +221,23 @@ mod tests {
202221

203222
let args = ScalarFunctionArgs {
204223
args: input_args,
205-
arg_fields: vec![Arc::new(Field::new(
206-
"item",
207-
List(FieldRef::new(Field::new("f", DataType::Int64, true))),
208-
false,
209-
))],
224+
arg_fields: vec![Arc::new(Field::new("item", DataType::Null, true))],
210225
number_rows: 1,
211226
return_field: Arc::new(Field::new(
212-
"item",
213-
List(FieldRef::new(Field::new_list_field(DataType::Int64, true))),
214-
false,
227+
"slice",
228+
DataType::List(Arc::new(Field::new_list_field(DataType::Null, true))),
229+
true,
215230
)),
216231
config_options: Arc::new(Default::default()),
217232
};
218233
let slice = SparkSlice::new();
219234
let result = slice.invoke_with_args(args).unwrap();
220-
assert_eq!(*result.to_array(1).unwrap(), *Arc::new(NullArray::new(1)));
235+
let arr = result.to_array(1).unwrap();
236+
let list = as_list_array(&arr).unwrap();
237+
assert_eq!(
238+
arr.data_type(),
239+
&DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
240+
);
241+
assert!(list.is_null(0));
221242
}
222243
}

datafusion/sqllogictest/test_files/spark/array/slice.slt

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,3 +132,8 @@ query ?
132132
SELECT slice(slice(NULL, 1, 2), 1, 2)
133133
----
134134
NULL
135+
136+
query ?
137+
SELECT slice(slice(make_array(NULL), 1, 2), 1, 2)
138+
----
139+
[NULL]

0 commit comments

Comments
 (0)