1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: { Array , ArrayRef , Int64Builder } ;
18+ use arrow:: array:: { Array , ArrayData , ArrayRef , Int64Builder , ListArray } ;
1919use arrow:: datatypes:: { DataType , Field , FieldRef } ;
2020use datafusion_common:: cast:: { as_int64_array, as_list_array} ;
2121use datafusion_common:: utils:: ListCoercion ;
22- use datafusion_common:: {
23- Result , ScalarValue , exec_err, internal_err, utils:: take_function_args,
24- } ;
22+ use datafusion_common:: { Result , exec_err, internal_err, utils:: take_function_args} ;
2523use datafusion_expr:: {
2624 ArrayFunctionArgument , ArrayFunctionSignature , ColumnarValue , ReturnFieldArgs ,
2725 ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature , Volatility ,
@@ -80,21 +78,26 @@ impl ScalarUDFImpl for SparkSlice {
8078 fn return_field_from_args ( & self , args : ReturnFieldArgs ) -> Result < FieldRef > {
8179 let nullable = args. arg_fields . iter ( ) . any ( |f| f. is_nullable ( ) ) ;
8280
83- Ok ( Arc :: new ( Field :: new (
84- "slice" ,
85- args. arg_fields [ 0 ] . data_type ( ) . clone ( ) ,
86- nullable,
87- ) ) )
81+ let data_type = match args. arg_fields [ 0 ] . data_type ( ) {
82+ DataType :: Null => {
83+ DataType :: List ( Arc :: new ( Field :: new_list_field ( DataType :: Null , true ) ) )
84+ }
85+ dt => dt. clone ( ) ,
86+ } ;
87+
88+ Ok ( Arc :: new ( Field :: new ( "slice" , data_type, nullable) ) )
8889 }
8990
9091 fn invoke_with_args (
9192 & self ,
9293 mut func_args : ScalarFunctionArgs ,
9394 ) -> Result < ColumnarValue > {
94- if func_args. args [ 0 ] . data_type ( ) == DataType :: Null
95- && let Some ( result) = check_null_types ( & func_args. args [ 0 ] )
96- {
97- return Ok ( result) ;
95+ if func_args. args [ 0 ] . data_type ( ) == DataType :: Null {
96+ let len = match & func_args. args [ 0 ] {
97+ ColumnarValue :: Array ( a) => a. len ( ) ,
98+ ColumnarValue :: Scalar ( _) => func_args. number_rows ,
99+ } ;
100+ return Ok ( ColumnarValue :: Array ( list_null_array ( len) ) ) ;
98101 }
99102
100103 let array_len = func_args
@@ -131,14 +134,9 @@ impl ScalarUDFImpl for SparkSlice {
131134 }
132135}
133136
134- fn check_null_types ( cv : & ColumnarValue ) -> Option < ColumnarValue > {
135- match cv {
136- ColumnarValue :: Scalar ( ScalarValue :: Null ) => {
137- Some ( ColumnarValue :: create_null_array ( 1 ) )
138- }
139- ColumnarValue :: Array ( _) => Some ( cv. clone ( ) ) ,
140- _ => None ,
141- }
137+ fn list_null_array ( len : usize ) -> ArrayRef {
138+ let list_type = DataType :: List ( Arc :: new ( Field :: new_list_field ( DataType :: Null , true ) ) ) ;
139+ Arc :: new ( ListArray :: from ( ArrayData :: new_null ( & list_type, len) ) )
142140}
143141
144142fn calculate_start_end ( args : & [ ArrayRef ] ) -> Result < ( ArrayRef , ArrayRef ) > {
@@ -188,9 +186,30 @@ fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
188186mod tests {
189187 use super :: * ;
190188 use arrow:: array:: NullArray ;
191- use arrow:: datatypes:: DataType :: List ;
192189 use arrow:: datatypes:: Field ;
193190 use datafusion_common:: ScalarValue ;
191+ use datafusion_common:: cast:: as_list_array;
192+ use datafusion_expr:: ReturnFieldArgs ;
193+
194+ #[ test]
195+ fn test_spark_slice_function_when_input_is_null ( ) {
196+ let slice = SparkSlice :: new ( ) ;
197+ let arg_fields: Vec < Arc < Field > > = vec ! [
198+ Arc :: new( Field :: new( "a" , DataType :: Null , true ) ) ,
199+ Arc :: new( Field :: new( "s" , DataType :: Int64 , true ) ) ,
200+ Arc :: new( Field :: new( "l" , DataType :: Int64 , true ) ) ,
201+ ] ;
202+ let out = slice
203+ . return_field_from_args ( ReturnFieldArgs {
204+ arg_fields : & arg_fields,
205+ scalar_arguments : & [ ] ,
206+ } )
207+ . unwrap ( ) ;
208+ assert_eq ! (
209+ out. data_type( ) ,
210+ & DataType :: List ( Arc :: new( Field :: new_list_field( DataType :: Null , true ) ) )
211+ ) ;
212+ }
194213
195214 #[ test]
196215 fn test_spark_slice_function_when_input_array_is_null ( ) {
@@ -202,21 +221,23 @@ mod tests {
202221
203222 let args = ScalarFunctionArgs {
204223 args : input_args,
205- arg_fields : vec ! [ Arc :: new( Field :: new(
206- "item" ,
207- List ( FieldRef :: new( Field :: new( "f" , DataType :: Int64 , true ) ) ) ,
208- false ,
209- ) ) ] ,
224+ arg_fields : vec ! [ Arc :: new( Field :: new( "item" , DataType :: Null , true ) ) ] ,
210225 number_rows : 1 ,
211226 return_field : Arc :: new ( Field :: new (
212- "item " ,
213- List ( FieldRef :: new ( Field :: new_list_field ( DataType :: Int64 , true ) ) ) ,
214- false ,
227+ "slice " ,
228+ DataType :: List ( Arc :: new ( Field :: new_list_field ( DataType :: Null , true ) ) ) ,
229+ true ,
215230 ) ) ,
216231 config_options : Arc :: new ( Default :: default ( ) ) ,
217232 } ;
218233 let slice = SparkSlice :: new ( ) ;
219234 let result = slice. invoke_with_args ( args) . unwrap ( ) ;
220- assert_eq ! ( * result. to_array( 1 ) . unwrap( ) , * Arc :: new( NullArray :: new( 1 ) ) ) ;
235+ let arr = result. to_array ( 1 ) . unwrap ( ) ;
236+ let list = as_list_array ( & arr) . unwrap ( ) ;
237+ assert_eq ! (
238+ arr. data_type( ) ,
239+ & DataType :: List ( Arc :: new( Field :: new_list_field( DataType :: Null , true ) ) )
240+ ) ;
241+ assert ! ( list. is_null( 0 ) ) ;
221242 }
222243}
0 commit comments