Skip to content

Commit 22bb4e6

Browse files
tabacgabotechs
andauthored
Add support for nested types to nullif. (#21764)
## Which issue does this PR close? - Closes #21763 ## Rationale for this change Add support for nested types to the `nullif` UDF. ## Are these changes tested? Unit tests included. ## Are there any user-facing changes? No changes to the function's signature. --------- Co-authored-by: Gabriel <45515538+gabotechs@users.noreply.github.com>
1 parent c2c0773 commit 22bb4e6

3 files changed

Lines changed: 114 additions & 5 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/functions/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ datafusion-execution = { workspace = true }
7878
datafusion-expr = { workspace = true }
7979
datafusion-expr-common = { workspace = true }
8080
datafusion-macros = { workspace = true }
81+
datafusion-physical-expr-common = { workspace = true }
8182
hex = { workspace = true, optional = true }
8283
itertools = { workspace = true }
8384
log = { workspace = true }

datafusion/functions/src/core/nullif.rs

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
use arrow::datatypes::DataType;
1919
use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
2020

21-
use arrow::compute::kernels::cmp::eq;
2221
use arrow::compute::kernels::nullif::nullif;
2322
use datafusion_common::{Result, ScalarValue, utils::take_function_args};
2423
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2524
use datafusion_macros::user_doc;
25+
use datafusion_physical_expr_common::datum::compare_with_eq;
2626

2727
#[user_doc(
2828
doc_section(label = "Conditional Functions"),
@@ -111,25 +111,29 @@ impl ScalarUDFImpl for NullIfFunc {
111111
/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
112112
fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
113113
let [lhs, rhs] = take_function_args("nullif", args)?;
114+
let is_nested = lhs.data_type().is_nested();
114115

115116
match (lhs, rhs) {
116117
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
117118
let rhs = rhs.to_scalar()?;
118-
let array = nullif(lhs, &eq(&lhs, &rhs)?)?;
119+
let eq_array = compare_with_eq(lhs, &rhs, is_nested)?;
120+
let array = nullif(lhs, &eq_array)?;
119121

120122
Ok(ColumnarValue::Array(array))
121123
}
122124
(ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => {
123-
let array = nullif(lhs, &eq(&lhs, &rhs)?)?;
125+
let eq_array = compare_with_eq(lhs, rhs, is_nested)?;
126+
let array = nullif(lhs, &eq_array)?;
124127
Ok(ColumnarValue::Array(array))
125128
}
126129
(ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => {
127130
let lhs_s = lhs.to_scalar()?;
128131
let lhs_a = lhs.to_array_of_size(rhs.len())?;
132+
let eq_array = compare_with_eq(&lhs_s, rhs, is_nested)?;
129133
let array = nullif(
130134
// nullif in arrow-select does not support Datum, so we need to convert to array
131135
lhs_a.as_ref(),
132-
&eq(&lhs_s, &rhs)?,
136+
&eq_array,
133137
)?;
134138
Ok(ColumnarValue::Array(array))
135139
}
@@ -148,7 +152,12 @@ fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
148152
mod tests {
149153
use std::sync::Arc;
150154

151-
use arrow::array::*;
155+
use arrow::{
156+
array::*,
157+
buffer::NullBuffer,
158+
datatypes::{Field, Fields, Int64Type},
159+
};
160+
use datafusion_common::DataFusionError;
152161

153162
use super::*;
154163

@@ -251,6 +260,104 @@ mod tests {
251260
Ok(())
252261
}
253262

263+
#[test]
264+
fn nullif_struct() -> Result<()> {
265+
let fields = Fields::from(vec![
266+
Field::new("a", DataType::Int64, true),
267+
Field::new("b", DataType::Utf8, true),
268+
]);
269+
270+
let lhs_a = Arc::new(Int64Array::from(vec![Some(1), Some(2), None]));
271+
let lhs_b = Arc::new(StringArray::from(vec![Some("1"), Some("2"), None]));
272+
let lhs_nulls = Some(NullBuffer::from(vec![true, true, false]));
273+
let lhs = ColumnarValue::Array(Arc::new(StructArray::new(
274+
fields.clone(),
275+
vec![lhs_a, lhs_b],
276+
lhs_nulls,
277+
)));
278+
279+
let rhs_a = Arc::new(Int64Array::from(vec![Some(1), Some(9), None]));
280+
let rhs_b = Arc::new(StringArray::from(vec![Some("1"), Some("2"), None]));
281+
let rhs_nulls = Some(NullBuffer::from(vec![true, true, false]));
282+
let rhs = ColumnarValue::Array(Arc::new(StructArray::new(
283+
fields.clone(),
284+
vec![rhs_a, rhs_b],
285+
rhs_nulls,
286+
)));
287+
288+
let result = nullif_func(&[lhs, rhs])?;
289+
let result = result.into_array(0).expect("Failed to convert to array");
290+
291+
let expected_arrays = vec![
292+
Arc::new(Int64Array::from(vec![None, Some(2), None])) as ArrayRef,
293+
Arc::new(StringArray::from(vec![None, Some("2"), None])) as ArrayRef,
294+
];
295+
let expected_nulls = NullBuffer::from(vec![false, true, false]);
296+
297+
let expected = Arc::new(StructArray::try_new(
298+
fields,
299+
expected_arrays,
300+
Some(expected_nulls),
301+
)?) as ArrayRef;
302+
303+
assert_eq!(expected.as_ref(), result.as_ref());
304+
305+
Ok(())
306+
}
307+
308+
#[test]
309+
fn nullif_list() -> Result<()> {
310+
let lhs = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
311+
Some(vec![Some(1), Some(2)]),
312+
Some(vec![Some(3)]),
313+
Some(vec![]),
314+
Some(vec![Some(5), Some(6), Some(7)]),
315+
None,
316+
]));
317+
let lhs = ColumnarValue::Array(lhs);
318+
319+
let rhs = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
320+
Some(vec![Some(1), Some(2)]),
321+
]));
322+
let rhs = ColumnarValue::Scalar(ScalarValue::List(rhs));
323+
324+
let result = nullif_func(&[lhs, rhs])?;
325+
let result = result.into_array(0).expect("Failed to convert to array");
326+
327+
let expected = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
328+
None,
329+
Some(vec![Some(3)]),
330+
Some(vec![]),
331+
Some(vec![Some(5), Some(6), Some(7)]),
332+
None,
333+
])) as ArrayRef;
334+
335+
assert_eq!(expected.as_ref(), result.as_ref());
336+
337+
Ok(())
338+
}
339+
340+
#[test]
341+
fn nullif_compare_nested_to_unnested() -> Result<()> {
342+
let lhs = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
343+
Some(vec![Some(1), Some(2)]),
344+
Some(vec![Some(3)]),
345+
Some(vec![]),
346+
Some(vec![Some(5), Some(6), Some(7)]),
347+
None,
348+
]));
349+
let lhs = ColumnarValue::Array(lhs);
350+
351+
let rhs = Arc::new(Int64Array::from(vec![Some(1), Some(3), None, None, None]));
352+
let rhs = ColumnarValue::Array(rhs);
353+
354+
let result = nullif_func(&[lhs, rhs]);
355+
356+
assert!(matches!(result, Err(DataFusionError::ArrowError(_, _))));
357+
358+
Ok(())
359+
}
360+
254361
#[test]
255362
fn nullif_literal_first() -> Result<()> {
256363
let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]);

0 commit comments

Comments
 (0)