1818use arrow:: datatypes:: DataType ;
1919use datafusion_expr:: { ColumnarValue , Documentation , ScalarFunctionArgs } ;
2020
21- use arrow:: compute:: kernels:: cmp:: eq;
2221use arrow:: compute:: kernels:: nullif:: nullif;
2322use datafusion_common:: { Result , ScalarValue , utils:: take_function_args} ;
2423use datafusion_expr:: { ScalarUDFImpl , Signature , Volatility } ;
2524use datafusion_macros:: user_doc;
25+ use datafusion_physical_expr_common:: datum:: compare_with_eq;
2626use std:: any:: Any ;
2727
2828#[ user_doc(
@@ -115,25 +115,29 @@ impl ScalarUDFImpl for NullIfFunc {
115115/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
116116fn nullif_func ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
117117 let [ lhs, rhs] = take_function_args ( "nullif" , args) ?;
118+ let is_nested = lhs. data_type ( ) . is_nested ( ) ;
118119
119120 match ( lhs, rhs) {
120121 ( ColumnarValue :: Array ( lhs) , ColumnarValue :: Scalar ( rhs) ) => {
121122 let rhs = rhs. to_scalar ( ) ?;
122- let array = nullif ( lhs, & eq ( & lhs, & rhs) ?) ?;
123+ let eq_array = compare_with_eq ( lhs, & rhs, is_nested) ?;
124+ let array = nullif ( lhs, & eq_array) ?;
123125
124126 Ok ( ColumnarValue :: Array ( array) )
125127 }
126128 ( ColumnarValue :: Array ( lhs) , ColumnarValue :: Array ( rhs) ) => {
127- let array = nullif ( lhs, & eq ( & lhs, & rhs) ?) ?;
129+ let eq_array = compare_with_eq ( lhs, rhs, is_nested) ?;
130+ let array = nullif ( lhs, & eq_array) ?;
128131 Ok ( ColumnarValue :: Array ( array) )
129132 }
130133 ( ColumnarValue :: Scalar ( lhs) , ColumnarValue :: Array ( rhs) ) => {
131134 let lhs_s = lhs. to_scalar ( ) ?;
132135 let lhs_a = lhs. to_array_of_size ( rhs. len ( ) ) ?;
136+ let eq_array = compare_with_eq ( & lhs_s, rhs, is_nested) ?;
133137 let array = nullif (
134138 // nullif in arrow-select does not support Datum, so we need to convert to array
135139 lhs_a. as_ref ( ) ,
136- & eq ( & lhs_s , & rhs ) ? ,
140+ & eq_array ,
137141 ) ?;
138142 Ok ( ColumnarValue :: Array ( array) )
139143 }
@@ -152,7 +156,12 @@ fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
152156mod tests {
153157 use std:: sync:: Arc ;
154158
155- use arrow:: array:: * ;
159+ use arrow:: {
160+ array:: * ,
161+ buffer:: NullBuffer ,
162+ datatypes:: { Field , Fields , Int64Type } ,
163+ } ;
164+ use datafusion_common:: DataFusionError ;
156165
157166 use super :: * ;
158167
@@ -255,6 +264,104 @@ mod tests {
255264 Ok ( ( ) )
256265 }
257266
267+ #[ test]
268+ fn nullif_struct ( ) -> Result < ( ) > {
269+ let fields = Fields :: from ( vec ! [
270+ Field :: new( "a" , DataType :: Int64 , true ) ,
271+ Field :: new( "b" , DataType :: Utf8 , true ) ,
272+ ] ) ;
273+
274+ let lhs_a = Arc :: new ( Int64Array :: from ( vec ! [ Some ( 1 ) , Some ( 2 ) , None ] ) ) ;
275+ let lhs_b = Arc :: new ( StringArray :: from ( vec ! [ Some ( "1" ) , Some ( "2" ) , None ] ) ) ;
276+ let lhs_nulls = Some ( NullBuffer :: from ( vec ! [ true , true , false ] ) ) ;
277+ let lhs = ColumnarValue :: Array ( Arc :: new ( StructArray :: new (
278+ fields. clone ( ) ,
279+ vec ! [ lhs_a, lhs_b] ,
280+ lhs_nulls,
281+ ) ) ) ;
282+
283+ let rhs_a = Arc :: new ( Int64Array :: from ( vec ! [ Some ( 1 ) , Some ( 9 ) , None ] ) ) ;
284+ let rhs_b = Arc :: new ( StringArray :: from ( vec ! [ Some ( "1" ) , Some ( "2" ) , None ] ) ) ;
285+ let rhs_nulls = Some ( NullBuffer :: from ( vec ! [ true , true , false ] ) ) ;
286+ let rhs = ColumnarValue :: Array ( Arc :: new ( StructArray :: new (
287+ fields. clone ( ) ,
288+ vec ! [ rhs_a, rhs_b] ,
289+ rhs_nulls,
290+ ) ) ) ;
291+
292+ let result = nullif_func ( & [ lhs, rhs] ) ?;
293+ let result = result. into_array ( 0 ) . expect ( "Failed to convert to array" ) ;
294+
295+ let expected_arrays = vec ! [
296+ Arc :: new( Int64Array :: from( vec![ None , Some ( 2 ) , None ] ) ) as ArrayRef ,
297+ Arc :: new( StringArray :: from( vec![ None , Some ( "2" ) , None ] ) ) as ArrayRef ,
298+ ] ;
299+ let expected_nulls = NullBuffer :: from ( vec ! [ false , true , false ] ) ;
300+
301+ let expected = Arc :: new ( StructArray :: try_new (
302+ fields,
303+ expected_arrays,
304+ Some ( expected_nulls) ,
305+ ) ?) as ArrayRef ;
306+
307+ assert_eq ! ( expected. as_ref( ) , result. as_ref( ) ) ;
308+
309+ Ok ( ( ) )
310+ }
311+
312+ #[ test]
313+ fn nullif_list ( ) -> Result < ( ) > {
314+ let lhs = Arc :: new ( ListArray :: from_iter_primitive :: < Int64Type , _ , _ > ( vec ! [
315+ Some ( vec![ Some ( 1 ) , Some ( 2 ) ] ) ,
316+ Some ( vec![ Some ( 3 ) ] ) ,
317+ Some ( vec![ ] ) ,
318+ Some ( vec![ Some ( 5 ) , Some ( 6 ) , Some ( 7 ) ] ) ,
319+ None ,
320+ ] ) ) ;
321+ let lhs = ColumnarValue :: Array ( lhs) ;
322+
323+ let rhs = Arc :: new ( ListArray :: from_iter_primitive :: < Int64Type , _ , _ > ( vec ! [
324+ Some ( vec![ Some ( 1 ) , Some ( 2 ) ] ) ,
325+ ] ) ) ;
326+ let rhs = ColumnarValue :: Scalar ( ScalarValue :: List ( rhs) ) ;
327+
328+ let result = nullif_func ( & [ lhs, rhs] ) ?;
329+ let result = result. into_array ( 0 ) . expect ( "Failed to convert to array" ) ;
330+
331+ let expected = Arc :: new ( ListArray :: from_iter_primitive :: < Int64Type , _ , _ > ( vec ! [
332+ None ,
333+ Some ( vec![ Some ( 3 ) ] ) ,
334+ Some ( vec![ ] ) ,
335+ Some ( vec![ Some ( 5 ) , Some ( 6 ) , Some ( 7 ) ] ) ,
336+ None ,
337+ ] ) ) as ArrayRef ;
338+
339+ assert_eq ! ( expected. as_ref( ) , result. as_ref( ) ) ;
340+
341+ Ok ( ( ) )
342+ }
343+
344+ #[ test]
345+ fn nullif_compare_nested_to_unnested ( ) -> Result < ( ) > {
346+ let lhs = Arc :: new ( ListArray :: from_iter_primitive :: < Int64Type , _ , _ > ( vec ! [
347+ Some ( vec![ Some ( 1 ) , Some ( 2 ) ] ) ,
348+ Some ( vec![ Some ( 3 ) ] ) ,
349+ Some ( vec![ ] ) ,
350+ Some ( vec![ Some ( 5 ) , Some ( 6 ) , Some ( 7 ) ] ) ,
351+ None ,
352+ ] ) ) ;
353+ let lhs = ColumnarValue :: Array ( lhs) ;
354+
355+ let rhs = Arc :: new ( Int64Array :: from ( vec ! [ Some ( 1 ) , Some ( 3 ) , None , None , None ] ) ) ;
356+ let rhs = ColumnarValue :: Array ( rhs) ;
357+
358+ let result = nullif_func ( & [ lhs, rhs] ) ;
359+
360+ assert ! ( matches!( result, Err ( DataFusionError :: ArrowError ( _, _) ) ) ) ;
361+
362+ Ok ( ( ) )
363+ }
364+
258365 #[ test]
259366 fn nullif_literal_first ( ) -> Result < ( ) > {
260367 let a = Int32Array :: from ( vec ! [ Some ( 1 ) , Some ( 2 ) , None , None , Some ( 3 ) , Some ( 4 ) ] ) ;
0 commit comments