1818use arrow:: datatypes:: DataType ;
1919use datafusion_expr:: { ColumnarValue , Documentation , ScalarFunctionArgs } ;
2020
21- use arrow:: compute:: kernels:: cmp:: eq;
2221use arrow:: compute:: kernels:: nullif:: nullif;
2322use datafusion_common:: { Result , ScalarValue , utils:: take_function_args} ;
2423use datafusion_expr:: { ScalarUDFImpl , Signature , Volatility } ;
2524use datafusion_macros:: user_doc;
25+ use datafusion_physical_expr_common:: datum:: compare_with_eq;
2626
2727#[ user_doc(
2828 doc_section( label = "Conditional Functions" ) ,
@@ -111,25 +111,29 @@ impl ScalarUDFImpl for NullIfFunc {
111111/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
112112fn nullif_func ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
113113 let [ lhs, rhs] = take_function_args ( "nullif" , args) ?;
114+ let is_nested = lhs. data_type ( ) . is_nested ( ) ;
114115
115116 match ( lhs, rhs) {
116117 ( ColumnarValue :: Array ( lhs) , ColumnarValue :: Scalar ( rhs) ) => {
117118 let rhs = rhs. to_scalar ( ) ?;
118- let array = nullif ( lhs, & eq ( & lhs, & rhs) ?) ?;
119+ let eq_array = compare_with_eq ( lhs, & rhs, is_nested) ?;
120+ let array = nullif ( lhs, & eq_array) ?;
119121
120122 Ok ( ColumnarValue :: Array ( array) )
121123 }
122124 ( ColumnarValue :: Array ( lhs) , ColumnarValue :: Array ( rhs) ) => {
123- let array = nullif ( lhs, & eq ( & lhs, & rhs) ?) ?;
125+ let eq_array = compare_with_eq ( lhs, rhs, is_nested) ?;
126+ let array = nullif ( lhs, & eq_array) ?;
124127 Ok ( ColumnarValue :: Array ( array) )
125128 }
126129 ( ColumnarValue :: Scalar ( lhs) , ColumnarValue :: Array ( rhs) ) => {
127130 let lhs_s = lhs. to_scalar ( ) ?;
128131 let lhs_a = lhs. to_array_of_size ( rhs. len ( ) ) ?;
132+ let eq_array = compare_with_eq ( & lhs_s, rhs, is_nested) ?;
129133 let array = nullif (
130134 // nullif in arrow-select does not support Datum, so we need to convert to array
131135 lhs_a. as_ref ( ) ,
132- & eq ( & lhs_s , & rhs ) ? ,
136+ & eq_array ,
133137 ) ?;
134138 Ok ( ColumnarValue :: Array ( array) )
135139 }
@@ -148,7 +152,12 @@ fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
148152mod tests {
149153 use std:: sync:: Arc ;
150154
151- use arrow:: array:: * ;
155+ use arrow:: {
156+ array:: * ,
157+ buffer:: NullBuffer ,
158+ datatypes:: { Field , Fields , Int64Type } ,
159+ } ;
160+ use datafusion_common:: DataFusionError ;
152161
153162 use super :: * ;
154163
@@ -251,6 +260,104 @@ mod tests {
251260 Ok ( ( ) )
252261 }
253262
/// Runs `nullif_func` on two three-row struct arrays (array vs. array) and
/// checks the result against a hand-built expected struct array.
#[test]
fn nullif_struct() -> Result<()> {
    // Both inputs share the same layout: a nullable Int64 child "a" and a
    // nullable Utf8 child "b".
    let schema = Fields::from(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Utf8, true),
    ]);

    // Left input: rows 0 and 1 are valid structs, row 2 is a null struct.
    let left_children = vec![
        Arc::new(Int64Array::from(vec![Some(1), Some(2), None])) as ArrayRef,
        Arc::new(StringArray::from(vec![Some("1"), Some("2"), None])) as ArrayRef,
    ];
    let left_validity = NullBuffer::from(vec![true, true, false]);
    let left_struct = StructArray::new(schema.clone(), left_children, Some(left_validity));
    let left = ColumnarValue::Array(Arc::new(left_struct));

    // Right input: identical to the left except that child "a" holds 9
    // instead of 2 in row 1.
    let right_children = vec![
        Arc::new(Int64Array::from(vec![Some(1), Some(9), None])) as ArrayRef,
        Arc::new(StringArray::from(vec![Some("1"), Some("2"), None])) as ArrayRef,
    ];
    let right_validity = NullBuffer::from(vec![true, true, false]);
    let right_struct =
        StructArray::new(schema.clone(), right_children, Some(right_validity));
    let right = ColumnarValue::Array(Arc::new(right_struct));

    let output = nullif_func(&[left, right])?;
    let output = output.into_array(0).expect("Failed to convert to array");

    // Expected output: only row 1 (where the two inputs differ) is valid.
    let want_children = vec![
        Arc::new(Int64Array::from(vec![None, Some(2), None])) as ArrayRef,
        Arc::new(StringArray::from(vec![None, Some("2"), None])) as ArrayRef,
    ];
    let want_validity = NullBuffer::from(vec![false, true, false]);
    let want_struct = StructArray::try_new(schema, want_children, Some(want_validity))?;
    let want = Arc::new(want_struct) as ArrayRef;

    assert_eq!(want.as_ref(), output.as_ref());

    Ok(())
}
307+
/// Runs `nullif_func` with a five-row list array on the left and a scalar
/// list `[1, 2]` on the right, then checks the result against a hand-built
/// expected list array.
#[test]
fn nullif_list() -> Result<()> {
    // Left input: four valid lists (one of them empty) followed by a null row.
    let left_rows: Vec<Option<Vec<Option<i64>>>> = vec![
        Some(vec![Some(1), Some(2)]),
        Some(vec![Some(3)]),
        Some(vec![]),
        Some(vec![Some(5), Some(6), Some(7)]),
        None,
    ];
    let left_list = ListArray::from_iter_primitive::<Int64Type, _, _>(left_rows);
    let left = ColumnarValue::Array(Arc::new(left_list));

    // Right input: a single-row list array wrapped as a list scalar.
    let right_rows: Vec<Option<Vec<Option<i64>>>> = vec![Some(vec![Some(1), Some(2)])];
    let right_list = ListArray::from_iter_primitive::<Int64Type, _, _>(right_rows);
    let right = ColumnarValue::Scalar(ScalarValue::List(Arc::new(right_list)));

    let output = nullif_func(&[left, right])?;
    let output = output.into_array(0).expect("Failed to convert to array");

    // Expected output: row 0 (equal to the scalar) is null; the remaining
    // rows match the left input.
    let want_rows: Vec<Option<Vec<Option<i64>>>> = vec![
        None,
        Some(vec![Some(3)]),
        Some(vec![]),
        Some(vec![Some(5), Some(6), Some(7)]),
        None,
    ];
    let want =
        Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(want_rows)) as ArrayRef;

    assert_eq!(want.as_ref(), output.as_ref());

    Ok(())
}
339+
/// Passes a list array on the left and a plain Int64 array on the right to
/// `nullif_func`, and asserts that the call returns
/// `Err(DataFusionError::ArrowError(..))`.
#[test]
fn nullif_compare_nested_to_unnested() -> Result<()> {
    // Nested (list) left-hand side with five rows.
    let list_rows: Vec<Option<Vec<Option<i64>>>> = vec![
        Some(vec![Some(1), Some(2)]),
        Some(vec![Some(3)]),
        Some(vec![]),
        Some(vec![Some(5), Some(6), Some(7)]),
        None,
    ];
    let nested: ArrayRef =
        Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(list_rows));

    // Non-nested (primitive Int64) right-hand side with the same row count.
    let flat: ArrayRef =
        Arc::new(Int64Array::from(vec![Some(1), Some(3), None, None, None]));

    let result = nullif_func(&[ColumnarValue::Array(nested), ColumnarValue::Array(flat)]);

    assert!(matches!(result, Err(DataFusionError::ArrowError(_, _))));

    Ok(())
}
360+
254361 #[ test]
255362 fn nullif_literal_first ( ) -> Result < ( ) > {
256363 let a = Int32Array :: from ( vec ! [ Some ( 1 ) , Some ( 2 ) , None , None , Some ( 3 ) , Some ( 4 ) ] ) ;
0 commit comments