diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index d65620ede38e..3d2f76b5d7fe 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -43,7 +43,7 @@ use datafusion_common::cast::{ use datafusion_common::{Result, exec_err, utils::take_function_args}; use itertools::Itertools; -use crate::utils::{compare_element_to_list, make_scalar_function}; +use crate::utils::{compare_element_to_list_fixed, make_scalar_function}; make_udf_expr_and_func!( ArrayPosition, @@ -209,9 +209,15 @@ fn resolve_start_from( Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) => { Ok(vec![v - 1; num_rows]) } + Some(ColumnarValue::Scalar(s)) if s.is_null() => { + exec_err!("array_position index cannot contain nulls") + } Some(ColumnarValue::Scalar(s)) => { exec_err!("array_position expected Int64 for start_from, got {s}") } + Some(ColumnarValue::Array(a)) if a.null_count() > 0 => { + exec_err!("array_position index cannot contain nulls") + } Some(ColumnarValue::Array(a)) => { Ok(as_int64_array(a)?.values().iter().map(|&x| x - 1).collect()) } @@ -306,11 +312,11 @@ fn general_position_dispatch(args: &[ArrayRef]) -> Result>() + let arr_from = as_int64_array(&args[2])?; + if arr_from.null_count() > 0 { + return exec_err!("array_position index cannot contain nulls"); + } + arr_from.values().iter().map(|&x| x - 1).collect::>() } else { vec![0; haystack.len()] }; @@ -321,10 +327,14 @@ fn general_position_dispatch(args: &[ArrayRef]) -> Result(haystack, needle, &arr_from) + if needle.data_type().is_list() { + generic_position::(haystack, needle, &arr_from) + } else { + generic_position::(haystack, needle, &arr_from) + } } -fn generic_position( +fn generic_position( haystack: &GenericListArray, needle: &ArrayRef, arr_from: &[i64], // 0-indexed @@ -332,12 +342,12 @@ fn generic_position( let mut data = Vec::with_capacity(haystack.len()); for (row_index, (row, &from)) in haystack.iter().zip(arr_from.iter()).enumerate() { - let from = from as usize; - if let Some(row) = row { - let eq_array = compare_element_to_list(&row, needle, row_index, true)?; + let eq_array = + compare_element_to_list_fixed::(&row, needle, row_index)?; // Collect `true`s in 1-indexed positions + let from = from as usize; let index = eq_array .iter() .skip(from) @@ -363,7 +373,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Searches for an element in the array, returns all occurrences.", + description = "Returns the positions of all occurrences of an element in the array. Returns an empty list `[]` if not found. Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL. Only returns NULL if the array to search itself is NULL.", syntax_example = "array_positions(array, element)", sql_example = r#"```sql > select array_positions([1, 2, 2, 3, 1, 4], 2); @@ -476,14 +486,24 @@ fn try_array_positions_scalar(args: &[ColumnarValue]) -> Result Result { let [haystack, needle] = take_function_args("array_positions", args)?; - match &haystack.data_type() { - List(_) => general_positions::(as_list_array(&haystack)?, needle), - LargeList(_) => general_positions::(as_large_list_array(&haystack)?, needle), - dt => exec_err!("array_positions does not support type '{dt}'"), + match (haystack.data_type(), needle.data_type().is_list()) { + (List(_), true) => { + general_positions::(as_list_array(&haystack)?, needle) + } + (LargeList(_), true) => { + general_positions::(as_large_list_array(&haystack)?, needle) + } + (List(_), false) => { + general_positions::(as_list_array(&haystack)?, needle) + } + (LargeList(_), false) => { + general_positions::(as_large_list_array(&haystack)?, needle) + } + (dt, _) => exec_err!("array_positions does not support type '{dt}'"), } } -fn general_positions( +fn general_positions( haystack: &GenericListArray, needle: &ArrayRef, ) -> Result { @@ -492,7 +512,8 @@ fn general_positions( for (row_index, row) in haystack.iter().enumerate() { if let Some(row) = row { - let eq_array = compare_element_to_list(&row, needle, row_index, true)?; + let eq_array = + compare_element_to_list_fixed::(&row, needle, row_index)?; // Collect `true`s in 1-indexed positions let indexes = eq_array @@ -591,7 +612,7 @@ fn array_positions_scalar( #[cfg(test)] mod tests { use super::*; - use arrow::array::AsArray; + use arrow::array::{AsArray, Int32Array, new_empty_array}; use arrow::datatypes::Int32Type; use datafusion_common::config::ConfigOptions; @@ -750,4 +771,173 @@ mod tests { Ok(()) } + + #[test] + fn test_nested_non_empty_null() -> Result<()> { + // Haystack Needle array_position array_positionS + // [[7]] [null] null [] + // [[7]] null null [] + // [[7], null] [null] null [] + // [[7], null] null 2 [2] + // [[7], [null]] [null] 2 [2] + // [[7], [null], null] [null] 2 [2] + + // Nulls are not zero sized and have underlying value of 7 + + // [[7], [7], [7], null, [7], null, [7], [null], [7], [null], null] + let inner = Arc::new(ListArray::new( + Field::new_list_field(DataType::Int32, true).into(), + OffsetBuffer::from_lengths(vec![1; 11]), + Arc::new(Int32Array::new( + vec![7; 11].into(), + Some( + vec![ + true, true, true, true, true, true, true, false, true, false, + true, + ] + .into(), + ), + )), + Some( + vec![ + true, true, true, false, true, false, true, true, true, true, false, + ] + .into(), + ), + )); + + // [[[7]], [[7]], [[7], null], [[7], null], [[7], [null]], [[7], [null], null]] + let haystack: Arc = Arc::new(ListArray::new( + Field::new_list_field(inner.data_type().clone(), true).into(), + OffsetBuffer::from_lengths(vec![1, 1, 2, 2, 2, 3]), + inner, + None, + )); + + // [[null], null, [null], null, [null], [null]] + let needle: Arc = Arc::new(ListArray::new( + Field::new_list_field(DataType::Int32, true).into(), + OffsetBuffer::from_lengths(vec![1; 6]), + Arc::new(Int32Array::new( + vec![7; 6].into(), + Some(vec![false; 6].into()), + )), + Some(vec![true, false, true, false, true, true].into()), + )); + + let output = ArrayPosition::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(&haystack)), + ColumnarValue::Array(Arc::clone(&needle)), + ], + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("", DataType::Null, true)), + config_options: Arc::new(ConfigOptions::default()), + })? + .into_array(9)?; + // [null, null, null, 2, 2, 2] + let expected: Arc = Arc::new(UInt64Array::from(vec![ + None, + None, + None, + Some(2), + Some(2), + Some(2), + ])); + assert_eq!(&output, &expected); + + let output = ArrayPositions::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(haystack), ColumnarValue::Array(needle)], + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("", DataType::Null, true)), + config_options: Arc::new(ConfigOptions::default()), + })? + .into_array(9)?; + // [[], [], [], [2], [2], [2]] + let expected: Arc = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![]), + Some(vec![]), + Some(vec![]), + Some(vec![Some(2)]), + Some(vec![Some(2)]), + Some(vec![Some(2)]), + ])); + assert_eq!(&output, &expected); + + Ok(()) + } + + #[test] + fn test_nested_empty_list() -> Result<()> { + // Haystack Needle array_position array_positionS + // [[]] null null [] + // [[7], []] [] 2 [2] + // [[7], null, []] [] 3 [3] + + // [[], [7], [], [7], null, []] + let inner = Arc::new(ListArray::new( + Field::new_list_field(DataType::Int32, true).into(), + OffsetBuffer::from_lengths(vec![0, 1, 0, 1, 0, 0]), + Arc::new(Int32Array::from(vec![7, 7])), + Some(vec![true, true, true, true, false, true].into()), + )); + + // [[[]], [[7], []], [[7], null, []]] + let haystack: Arc = Arc::new(ListArray::new( + Field::new_list_field(inner.data_type().clone(), true).into(), + OffsetBuffer::from_lengths(vec![1, 2, 3]), + inner, + None, + )); + + // [null, [], []] + let needle: Arc = Arc::new(ListArray::new( + Field::new_list_field(DataType::Int32, true).into(), + OffsetBuffer::from_lengths(vec![0, 0, 0]), + Arc::new(new_empty_array(&DataType::Int32)), + Some(vec![false, true, true].into()), + )); + + let output = ArrayPosition::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(&haystack)), + ColumnarValue::Array(Arc::clone(&needle)), + ], + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("", DataType::Null, true)), + config_options: Arc::new(ConfigOptions::default()), + })? + .into_array(9)?; + // [null, 2, 3] + let expected: Arc = + Arc::new(UInt64Array::from(vec![None, Some(2), Some(3)])); + assert_eq!(&output, &expected); + + let output = ArrayPositions::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(haystack), ColumnarValue::Array(needle)], + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("", DataType::Null, true)), + config_options: Arc::new(ConfigOptions::default()), + })? + .into_array(9)?; + // [[], [2], [3]] + let expected: Arc = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![]), + Some(vec![Some(2)]), + Some(vec![Some(3)]), + ])); + assert_eq!(&output, &expected); + + Ok(()) + } } diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 9f46917a87eb..7fff83ca2170 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -19,10 +19,12 @@ use std::sync::Arc; +use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Fields}; use arrow::array::{ Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, + make_comparator, }; use arrow::buffer::OffsetBuffer; use datafusion_common::cast::{ @@ -220,6 +222,30 @@ pub(crate) fn compare_element_to_list( Ok(res) } +/// Given a `haystack` array, and a specific value from `needle` selected by +/// `needle_element_index`, return a `BooleanArray` based on whether the elements +/// in `haystack` match the `needle` value using `IS NOT DISTINCT FROM` semantics. +/// - Allows NULL = NULL to be considered true +pub(crate) fn compare_element_to_list_fixed( + haystack: &dyn Array, + needle: &dyn Array, + needle_element_index: usize, +) -> Result { + if IS_LIST { + // arrow_ord::cmp::eq does not support ListArray, so we resort to make_comparator + let cmp = make_comparator(haystack, needle, SortOptions::default())?; + let res = (0..haystack.len()) + .map(|i| cmp(i, needle_element_index).is_eq()) + .collect::(); + Ok(res) + } else { + let needle = needle.slice(needle_element_index, 1); + let needle_value = Scalar::new(needle); + // use not_distinct so we can compare NULL + Ok(arrow_ord::cmp::not_distinct(&haystack, &needle_value)?) + } +} + /// Returns the length of each array dimension pub(crate) fn compute_array_dims( arr: Option, diff --git a/datafusion/sqllogictest/test_files/array/array_position.slt b/datafusion/sqllogictest/test_files/array/array_position.slt index 07e3d3143592..0482b83ea854 100644 --- a/datafusion/sqllogictest/test_files/array/array_position.slt +++ b/datafusion/sqllogictest/test_files/array/array_position.slt @@ -314,10 +314,9 @@ select array_positions([1, 2, 3, 4, 5], null); #TODO: https://github.com/apache/datafusion/issues/7142 # array_positions with NULL (follow PostgreSQL) -#query ? -#select array_positions(null, 1); -#---- -#NULL +# expected to return null +query error DataFusion error: Execution error: array_positions does not support type 'Null' +select array_positions(null, 1); # array_positions scalar function #1 query ??? @@ -458,4 +457,74 @@ select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array [1] [] +query II? rowsort +select id, array_position(haystack, needle), array_positions(haystack, needle) +from values + (1, [1, 2], 1), -- element found, return position + (2, [1, 2], 0), -- element missing, return null & empty list + (3, [1, 2], null), -- null is an element, which is missing, return as above + (4, [1, null], 1), -- element found (regardless of null), return position + (5, [1, null], 0), -- element missing, return null & empty list + (6, [1, null], null), -- null is found as element so return position + (7, null, 1), -- when haystack is null we always return null + (8, null, null) -- ^ +t(id, haystack, needle); +---- +1 1 [1] +2 NULL [] +3 NULL [] +4 1 [1] +5 NULL [] +6 2 [2] +7 NULL NULL +8 NULL NULL + +query TI? rowsort +select id, array_position(haystack, needle), array_positions(haystack, needle) +from values + ('01', [[1], [2]], [1]), -- [1] is the sublist element, found it so return position + ('02', [[1], [2]], [0]), -- not found, return null & [] + ('03', [[1], [2]], [null]), -- [null] is element, not found as above + ('04', [[1], [2]], null), -- not found, return null & [] + ('05', [[1], null], [1]), -- [1] is found even though we have null, return position + ('06', [[1], null], [0]), -- not found, return null & [] + ('07', [[1], null], [null]), -- ^ + ('08', [[1], null], null), -- null itself is an element, which exists so return position + ('09', [[]], null), -- no null element, return null & [] + ('10', [[1], [null]], [null]), -- [null] is found, return position + ('11', [[1], [null], null], [null]), -- ^ (ignore regular null) + ('12', [[1], []], []), -- ^ ([] is element) + ('13', [[1], null, []], []), -- ^ (ignore regular null) + ('14', null, [1]), -- haystack null, return null + ('15', null, [null]), -- ^ + ('16', null, null) -- ^ +t(id, haystack, needle); +---- +01 1 [1] +02 NULL [] +03 NULL [] +04 NULL [] +05 1 [1] +06 NULL [] +07 NULL [] +08 2 [2] +09 NULL [] +10 2 [2] +11 2 [2] +12 2 [2] +13 3 [3] +14 NULL NULL +15 NULL NULL +16 NULL NULL + +query error DataFusion error: Execution error: array_position index cannot contain nulls +select array_position(haystack, needle, index_from) +from values + ([1, 2], 1, 2), + ([1, 2], 1, null) +t(haystack, needle, index_from) + +query error DataFusion error: Execution error: array_position index cannot contain nulls +select array_position([1, 2], 1, null) + include ./cleanup.slt.part diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index ffa04c3013f8..062c0c207c0e 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3938,7 +3938,7 @@ array_position(array, element, index) ### `array_positions` -Searches for an element in the array, returns all occurrences. +Returns the positions of all occurrences of an element in the array. Returns an empty list `[]` if not found. Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL. Only returns NULL if the array to search itself is NULL. ```sql array_positions(array, element)