Skip to content

Commit 6245ba6

Browse files
fix: Fix Spark slice function failing to cast Null-typed input to GenericListArray
1 parent c17c87c commit 6245ba6

2 files changed

Lines changed: 62 additions & 1 deletion

File tree

  • datafusion

datafusion/spark/src/function/array/slice.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ use arrow::array::{Array, ArrayRef, Int64Builder};
1919
use arrow::datatypes::{DataType, Field, FieldRef};
2020
use datafusion_common::cast::{as_int64_array, as_list_array};
2121
use datafusion_common::utils::ListCoercion;
22-
use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
22+
use datafusion_common::{
23+
DataFusionError, Result, exec_err, internal_err, utils::take_function_args,
24+
};
2325
use datafusion_expr::{
2426
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
2527
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -89,6 +91,10 @@ impl ScalarUDFImpl for SparkSlice {
8991
&self,
9092
mut func_args: ScalarFunctionArgs,
9193
) -> Result<ColumnarValue> {
94+
if func_args.args[0].data_type() == DataType::Null {
95+
return Ok::<ColumnarValue, DataFusionError>(func_args.args[0].clone());
96+
};
97+
9298
let array_len = func_args
9399
.args
94100
.iter()
@@ -165,3 +171,40 @@ fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
165171

166172
Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish())))
167173
}
174+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::NullArray;
    use arrow::datatypes::DataType::List;
    use arrow::datatypes::Field;
    use datafusion_common::ScalarValue;

    /// Regression test: when the first argument has the `Null` data type,
    /// `invoke_with_args` must hand that argument back unchanged rather than
    /// attempting to treat it as a list array.
    #[test]
    fn test_spark_slice_function_when_input_array_is_null() {
        // Single source of truth for the batch size so the input array,
        // `number_rows`, and the expected output cannot drift apart.
        let num_rows = 1;

        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(NullArray::new(num_rows))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
            ],
            // One field per argument, with types matching the values above.
            arg_fields: vec![
                Arc::new(Field::new("array", DataType::Null, true)),
                Arc::new(Field::new("start", DataType::Int64, false)),
                Arc::new(Field::new("length", DataType::Int64, false)),
            ],
            number_rows: num_rows,
            return_field: Arc::new(Field::new(
                "item",
                List(FieldRef::new(Field::new_list_field(DataType::Int64, true))),
                false,
            )),
            config_options: Arc::new(Default::default()),
        };

        let result = SparkSlice::new()
            .invoke_with_args(args)
            .unwrap()
            .to_array(num_rows)
            .unwrap();

        // Compare as `ArrayRef`s so a failure prints both arrays.
        let expected: ArrayRef = Arc::new(NullArray::new(num_rows));
        assert_eq!(&result, &expected);
    }
}

datafusion/sqllogictest/test_files/spark/array/slice.slt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,21 @@ query ?
114114
SELECT slice([1, 2, 3, 4], CAST('2' AS INT), 4);
115115
----
116116
[2, 3, 4]
117+
118+
query ?
119+
SELECT slice(column1, column2, column3)
120+
FROM VALUES
121+
(NULL, 1, 2),
122+
(NULL, 1, -2),
123+
(NULL, -1, 2),
124+
(NULL, 0, 2);
125+
----
126+
NULL
127+
NULL
128+
NULL
129+
NULL
130+
131+
query ?
132+
SELECT slice(slice(NULL, 1, 2), 1, 2)
133+
----
134+
NULL

0 commit comments

Comments
 (0)