Skip to content

Commit 687d16a

Browse files
tabacgabotechs
andcommitted
Add support for nested types to nullif. (apache#21764)
- Closes apache#21763 Add support for nested types to the `nullif` UDF. Unit tests included. No changes to the function's signature. --------- Co-authored-by: Gabriel <45515538+gabotechs@users.noreply.github.com> (cherry picked from commit 22bb4e6)
1 parent c7668ba commit 687d16a

3 files changed

Lines changed: 114 additions & 5 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/functions/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ datafusion-execution = { workspace = true }
7878
datafusion-expr = { workspace = true }
7979
datafusion-expr-common = { workspace = true }
8080
datafusion-macros = { workspace = true }
81+
datafusion-physical-expr-common = { workspace = true }
8182
hex = { workspace = true, optional = true }
8283
itertools = { workspace = true }
8384
log = { workspace = true }

datafusion/functions/src/core/nullif.rs

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
use arrow::datatypes::DataType;
1919
use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
2020

21-
use arrow::compute::kernels::cmp::eq;
2221
use arrow::compute::kernels::nullif::nullif;
2322
use datafusion_common::{Result, ScalarValue, utils::take_function_args};
2423
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2524
use datafusion_macros::user_doc;
25+
use datafusion_physical_expr_common::datum::compare_with_eq;
2626
use std::any::Any;
2727

2828
#[user_doc(
@@ -115,25 +115,29 @@ impl ScalarUDFImpl for NullIfFunc {
115115
/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
116116
fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
117117
let [lhs, rhs] = take_function_args("nullif", args)?;
118+
let is_nested = lhs.data_type().is_nested();
118119

119120
match (lhs, rhs) {
120121
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
121122
let rhs = rhs.to_scalar()?;
122-
let array = nullif(lhs, &eq(&lhs, &rhs)?)?;
123+
let eq_array = compare_with_eq(lhs, &rhs, is_nested)?;
124+
let array = nullif(lhs, &eq_array)?;
123125

124126
Ok(ColumnarValue::Array(array))
125127
}
126128
(ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => {
127-
let array = nullif(lhs, &eq(&lhs, &rhs)?)?;
129+
let eq_array = compare_with_eq(lhs, rhs, is_nested)?;
130+
let array = nullif(lhs, &eq_array)?;
128131
Ok(ColumnarValue::Array(array))
129132
}
130133
(ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => {
131134
let lhs_s = lhs.to_scalar()?;
132135
let lhs_a = lhs.to_array_of_size(rhs.len())?;
136+
let eq_array = compare_with_eq(&lhs_s, rhs, is_nested)?;
133137
let array = nullif(
134138
// nullif in arrow-select does not support Datum, so we need to convert to array
135139
lhs_a.as_ref(),
136-
&eq(&lhs_s, &rhs)?,
140+
&eq_array,
137141
)?;
138142
Ok(ColumnarValue::Array(array))
139143
}
@@ -152,7 +156,12 @@ fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
152156
mod tests {
153157
use std::sync::Arc;
154158

155-
use arrow::array::*;
159+
use arrow::{
160+
array::*,
161+
buffer::NullBuffer,
162+
datatypes::{Field, Fields, Int64Type},
163+
};
164+
use datafusion_common::DataFusionError;
156165

157166
use super::*;
158167

@@ -255,6 +264,104 @@ mod tests {
255264
Ok(())
256265
}
257266

267+
#[test]
268+
fn nullif_struct() -> Result<()> {
269+
let fields = Fields::from(vec![
270+
Field::new("a", DataType::Int64, true),
271+
Field::new("b", DataType::Utf8, true),
272+
]);
273+
274+
let lhs_a = Arc::new(Int64Array::from(vec![Some(1), Some(2), None]));
275+
let lhs_b = Arc::new(StringArray::from(vec![Some("1"), Some("2"), None]));
276+
let lhs_nulls = Some(NullBuffer::from(vec![true, true, false]));
277+
let lhs = ColumnarValue::Array(Arc::new(StructArray::new(
278+
fields.clone(),
279+
vec![lhs_a, lhs_b],
280+
lhs_nulls,
281+
)));
282+
283+
let rhs_a = Arc::new(Int64Array::from(vec![Some(1), Some(9), None]));
284+
let rhs_b = Arc::new(StringArray::from(vec![Some("1"), Some("2"), None]));
285+
let rhs_nulls = Some(NullBuffer::from(vec![true, true, false]));
286+
let rhs = ColumnarValue::Array(Arc::new(StructArray::new(
287+
fields.clone(),
288+
vec![rhs_a, rhs_b],
289+
rhs_nulls,
290+
)));
291+
292+
let result = nullif_func(&[lhs, rhs])?;
293+
let result = result.into_array(0).expect("Failed to convert to array");
294+
295+
let expected_arrays = vec![
296+
Arc::new(Int64Array::from(vec![None, Some(2), None])) as ArrayRef,
297+
Arc::new(StringArray::from(vec![None, Some("2"), None])) as ArrayRef,
298+
];
299+
let expected_nulls = NullBuffer::from(vec![false, true, false]);
300+
301+
let expected = Arc::new(StructArray::try_new(
302+
fields,
303+
expected_arrays,
304+
Some(expected_nulls),
305+
)?) as ArrayRef;
306+
307+
assert_eq!(expected.as_ref(), result.as_ref());
308+
309+
Ok(())
310+
}
311+
312+
#[test]
313+
fn nullif_list() -> Result<()> {
314+
let lhs = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
315+
Some(vec![Some(1), Some(2)]),
316+
Some(vec![Some(3)]),
317+
Some(vec![]),
318+
Some(vec![Some(5), Some(6), Some(7)]),
319+
None,
320+
]));
321+
let lhs = ColumnarValue::Array(lhs);
322+
323+
let rhs = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
324+
Some(vec![Some(1), Some(2)]),
325+
]));
326+
let rhs = ColumnarValue::Scalar(ScalarValue::List(rhs));
327+
328+
let result = nullif_func(&[lhs, rhs])?;
329+
let result = result.into_array(0).expect("Failed to convert to array");
330+
331+
let expected = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
332+
None,
333+
Some(vec![Some(3)]),
334+
Some(vec![]),
335+
Some(vec![Some(5), Some(6), Some(7)]),
336+
None,
337+
])) as ArrayRef;
338+
339+
assert_eq!(expected.as_ref(), result.as_ref());
340+
341+
Ok(())
342+
}
343+
344+
#[test]
345+
fn nullif_compare_nested_to_unnested() -> Result<()> {
346+
let lhs = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
347+
Some(vec![Some(1), Some(2)]),
348+
Some(vec![Some(3)]),
349+
Some(vec![]),
350+
Some(vec![Some(5), Some(6), Some(7)]),
351+
None,
352+
]));
353+
let lhs = ColumnarValue::Array(lhs);
354+
355+
let rhs = Arc::new(Int64Array::from(vec![Some(1), Some(3), None, None, None]));
356+
let rhs = ColumnarValue::Array(rhs);
357+
358+
let result = nullif_func(&[lhs, rhs]);
359+
360+
assert!(matches!(result, Err(DataFusionError::ArrowError(_, _))));
361+
362+
Ok(())
363+
}
364+
258365
#[test]
259366
fn nullif_literal_first() -> Result<()> {
260367
let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]);

0 commit comments

Comments
 (0)