Skip to content

Commit 2f90194

Browse files
lyne7-sc, martin-g, Jefffrey
authored
Fix array_repeat handling of null count values (#20102)
## Which issue does this PR close?

<!--
We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123.
-->

- Closes #20075.

## Rationale for this change

The previous implementation of `array_repeat` relied on Arrow defaults when handling null and negative count values. As a result, null counts were implicitly treated as zero and returned empty arrays, which is a correctness issue. This PR makes the handling of these edge cases explicit and aligns the function with SQL null semantics.

<!--
Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes.
-->

## What changes are included in this PR?

- Explicit handling of null and negative count values
- Planner-time coercion of the count argument to `Int64`

<!--
There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR.
-->

## Are these changes tested?

<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)?
-->

Yes, SLTs added and pass.

## Are there any user-facing changes?

Yes. When the count value is null, `array_repeat` now returns a null array instead of an empty array.

<!--
If there are user-facing changes then we may require documentation to be updated before approving the PR.
-->

<!--
If there are any breaking changes to public APIs, please add the `api change` label.
-->

---------

Co-authored-by: Martin Grigorov <martin-g@users.noreply.github.com>
Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com>
1 parent 71bc68f commit 2f90194

2 files changed

Lines changed: 164 additions & 65 deletions

File tree

datafusion/functions-nested/src/repeat.rs

Lines changed: 86 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,23 @@
1919
2020
use crate::utils::make_scalar_function;
2121
use arrow::array::{
22-
Array, ArrayRef, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, UInt64Array,
22+
Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, OffsetSizeTrait,
23+
UInt64Array,
2324
};
2425
use arrow::buffer::{NullBuffer, OffsetBuffer};
2526
use arrow::compute;
26-
use arrow::compute::cast;
2727
use arrow::datatypes::DataType;
2828
use arrow::datatypes::{
2929
DataType::{LargeList, List},
3030
Field,
3131
};
32-
use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array};
33-
use datafusion_common::{Result, exec_err, utils::take_function_args};
32+
use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
33+
use datafusion_common::types::{NativeType, logical_int64};
34+
use datafusion_common::{DataFusionError, Result};
3435
use datafusion_expr::{
3536
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
3637
};
38+
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
3739
use datafusion_macros::user_doc;
3840
use std::any::Any;
3941
use std::sync::Arc;
@@ -88,7 +90,17 @@ impl Default for ArrayRepeat {
8890
impl ArrayRepeat {
8991
pub fn new() -> Self {
9092
Self {
91-
signature: Signature::user_defined(Volatility::Immutable),
93+
signature: Signature::coercible(
94+
vec![
95+
Coercion::new_exact(TypeSignatureClass::Any),
96+
Coercion::new_implicit(
97+
TypeSignatureClass::Native(logical_int64()),
98+
vec![TypeSignatureClass::Integer],
99+
NativeType::Int64,
100+
),
101+
],
102+
Volatility::Immutable,
103+
),
92104
aliases: vec![String::from("list_repeat")],
93105
}
94106
}
@@ -132,39 +144,14 @@ impl ScalarUDFImpl for ArrayRepeat {
132144
&self.aliases
133145
}
134146

135-
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
136-
let [first_type, second_type] = take_function_args(self.name(), arg_types)?;
137-
138-
// Coerce the second argument to Int64/UInt64 if it's a numeric type
139-
let second = match second_type {
140-
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
141-
DataType::Int64
142-
}
143-
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
144-
DataType::UInt64
145-
}
146-
_ => return exec_err!("count must be an integer type"),
147-
};
148-
149-
Ok(vec![first_type.clone(), second])
150-
}
151-
152147
fn documentation(&self) -> Option<&Documentation> {
153148
self.doc()
154149
}
155150
}
156151

157152
fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
158153
let element = &args[0];
159-
let count_array = &args[1];
160-
161-
let count_array = match count_array.data_type() {
162-
DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
163-
DataType::UInt64 => count_array,
164-
_ => return exec_err!("count must be an integer type"),
165-
};
166-
167-
let count_array = as_uint64_array(count_array)?;
154+
let count_array = as_int64_array(&args[1])?;
168155

169156
match element.data_type() {
170157
List(_) => {
@@ -193,21 +180,31 @@ fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
193180
/// ```
194181
fn general_repeat<O: OffsetSizeTrait>(
195182
array: &ArrayRef,
196-
count_array: &UInt64Array,
183+
count_array: &Int64Array,
197184
) -> Result<ArrayRef> {
198-
// Build offsets and take_indices
199-
let total_repeated_values: usize =
200-
count_array.values().iter().map(|&c| c as usize).sum();
185+
let total_repeated_values: usize = (0..count_array.len())
186+
.map(|i| get_count_with_validity(count_array, i))
187+
.sum();
188+
201189
let mut take_indices = Vec::with_capacity(total_repeated_values);
202190
let mut offsets = Vec::with_capacity(count_array.len() + 1);
203191
offsets.push(O::zero());
204192
let mut running_offset = 0usize;
205193

206-
for (idx, &count) in count_array.values().iter().enumerate() {
207-
let count = count as usize;
208-
running_offset += count;
209-
offsets.push(O::from_usize(running_offset).unwrap());
210-
take_indices.extend(std::iter::repeat_n(idx as u64, count))
194+
for idx in 0..count_array.len() {
195+
let count = get_count_with_validity(count_array, idx);
196+
running_offset = running_offset.checked_add(count).ok_or_else(|| {
197+
DataFusionError::Execution(
198+
"array_repeat: running_offset overflowed usize".to_string(),
199+
)
200+
})?;
201+
let offset = O::from_usize(running_offset).ok_or_else(|| {
202+
DataFusionError::Execution(format!(
203+
"array_repeat: offset {running_offset} exceeds the maximum value for offset type"
204+
))
205+
})?;
206+
offsets.push(offset);
207+
take_indices.extend(std::iter::repeat_n(idx as u64, count));
211208
}
212209

213210
// Build the flattened values
@@ -222,7 +219,7 @@ fn general_repeat<O: OffsetSizeTrait>(
222219
Arc::new(Field::new_list_field(array.data_type().to_owned(), true)),
223220
OffsetBuffer::new(offsets.into()),
224221
repeated_values,
225-
None,
222+
count_array.nulls().cloned(),
226223
)?))
227224
}
228225

@@ -238,23 +235,24 @@ fn general_repeat<O: OffsetSizeTrait>(
238235
/// ```
239236
fn general_list_repeat<O: OffsetSizeTrait>(
240237
list_array: &GenericListArray<O>,
241-
count_array: &UInt64Array,
238+
count_array: &Int64Array,
242239
) -> Result<ArrayRef> {
243-
let counts = count_array.values();
244240
let list_offsets = list_array.value_offsets();
245241

246242
// calculate capacities for pre-allocation
247-
let outer_total = counts.iter().map(|&c| c as usize).sum();
248-
let inner_total = counts
249-
.iter()
250-
.enumerate()
251-
.filter(|&(i, _)| !list_array.is_null(i))
252-
.map(|(i, &c)| {
253-
let len = list_offsets[i + 1].to_usize().unwrap()
254-
- list_offsets[i].to_usize().unwrap();
255-
len * (c as usize)
256-
})
257-
.sum();
243+
let mut outer_total = 0usize;
244+
let mut inner_total = 0usize;
245+
for i in 0..count_array.len() {
246+
let count = get_count_with_validity(count_array, i);
247+
if count > 0 {
248+
outer_total += count;
249+
if list_array.is_valid(i) {
250+
let len = list_offsets[i + 1].to_usize().unwrap()
251+
- list_offsets[i].to_usize().unwrap();
252+
inner_total += len * count;
253+
}
254+
}
255+
}
258256

259257
// Build inner structures
260258
let mut inner_offsets = Vec::with_capacity(outer_total + 1);
@@ -263,17 +261,27 @@ fn general_list_repeat<O: OffsetSizeTrait>(
263261
let mut inner_running = 0usize;
264262
inner_offsets.push(O::zero());
265263

266-
for (row_idx, &count) in counts.iter().enumerate() {
267-
let is_valid = !list_array.is_null(row_idx);
264+
for row_idx in 0..count_array.len() {
265+
let count = get_count_with_validity(count_array, row_idx);
266+
let list_is_valid = list_array.is_valid(row_idx);
268267
let start = list_offsets[row_idx].to_usize().unwrap();
269268
let end = list_offsets[row_idx + 1].to_usize().unwrap();
270269
let row_len = end - start;
271270

272271
for _ in 0..count {
273-
inner_running += row_len;
274-
inner_offsets.push(O::from_usize(inner_running).unwrap());
275-
inner_nulls.append(is_valid);
276-
if is_valid {
272+
inner_running = inner_running.checked_add(row_len).ok_or_else(|| {
273+
DataFusionError::Execution(
274+
"array_repeat: inner offset overflowed usize".to_string(),
275+
)
276+
})?;
277+
let offset = O::from_usize(inner_running).ok_or_else(|| {
278+
DataFusionError::Execution(format!(
279+
"array_repeat: offset {inner_running} exceeds the maximum value for offset type"
280+
))
281+
})?;
282+
inner_offsets.push(offset);
283+
inner_nulls.append(list_is_valid);
284+
if list_is_valid {
277285
take_indices.extend(start as u64..end as u64);
278286
}
279287
}
@@ -298,8 +306,24 @@ fn general_list_repeat<O: OffsetSizeTrait>(
298306
list_array.data_type().to_owned(),
299307
true,
300308
)),
301-
OffsetBuffer::<O>::from_lengths(counts.iter().map(|&c| c as usize)),
309+
OffsetBuffer::<O>::from_lengths(
310+
count_array
311+
.iter()
312+
.map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)),
313+
),
302314
Arc::new(inner_list),
303-
None,
315+
count_array.nulls().cloned(),
304316
)?))
305317
}
318+
319+
/// Returns the repeat count stored at position `idx` of `count_array` as a `usize`.
320+
/// Returns 0 when the entry is null, and also when the count is zero or negative.
///
/// NOTE(review): the positive branch narrows with `c as usize`; on targets where
/// `usize` is narrower than 64 bits a large `i64` would be truncated rather than
/// rejected. Confirm whether such targets matter here, and if so consider
/// `usize::try_from(c)`.
321+
#[inline]
322+
fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize {
323+
// A null count contributes no repeated values.
if count_array.is_null(idx) {
324+
0
325+
} else {
326+
let c = count_array.value(idx);
327+
// Clamp non-positive counts to zero before converting to `usize`.
if c > 0 { c as usize } else { 0 }
328+
}
329+
}

datafusion/sqllogictest/test_files/array.slt

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3256,24 +3256,99 @@ drop table array_repeat_table;
32563256
statement ok
32573257
drop table large_array_repeat_table;
32583258

3259-
3259+
# array_repeat: arrays with NULL counts
32603260
statement ok
32613261
create table array_repeat_null_count_table
32623262
as values
32633263
(1, 2),
32643264
(2, null),
3265-
(3, 1);
3265+
(3, 1),
3266+
(4, -1),
3267+
(null, null);
32663268

32673269
query I?
32683270
select column1, array_repeat(column1, column2) from array_repeat_null_count_table;
32693271
----
32703272
1 [1, 1]
3271-
2 []
3273+
2 NULL
32723274
3 [3]
3275+
4 []
3276+
NULL NULL
32733277

32743278
statement ok
32753279
drop table array_repeat_null_count_table
32763280

3281+
# array_repeat: nested arrays with NULL counts
3282+
statement ok
3283+
create table array_repeat_nested_null_count_table
3284+
as values
3285+
([[1, 2], [3, 4]], 2),
3286+
([[5, 6], [7, 8]], null),
3287+
([[null, null], [9, 10]], 1),
3288+
(null, 3),
3289+
([[11, 12]], -1);
3290+
3291+
query ??
3292+
select column1, array_repeat(column1, column2) from array_repeat_nested_null_count_table;
3293+
----
3294+
[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
3295+
[[5, 6], [7, 8]] NULL
3296+
[[NULL, NULL], [9, 10]] [[[NULL, NULL], [9, 10]]]
3297+
NULL [NULL, NULL, NULL]
3298+
[[11, 12]] []
3299+
3300+
statement ok
3301+
drop table array_repeat_nested_null_count_table
3302+
3303+
# array_repeat edge cases: empty arrays
3304+
query ???
3305+
select array_repeat([], 3), array_repeat([], 0), array_repeat([], null);
3306+
----
3307+
[[], [], []] [] NULL
3308+
3309+
query ??
3310+
select array_repeat(null::int, 0), array_repeat(null::int, null);
3311+
----
3312+
[] NULL
3313+
3314+
# array_repeat LargeList with NULL count
3315+
statement ok
3316+
create table array_repeat_large_list_null_table
3317+
as values
3318+
(arrow_cast([1, 2, 3], 'LargeList(Int64)'), 2),
3319+
(arrow_cast([4, 5], 'LargeList(Int64)'), null),
3320+
(arrow_cast(null, 'LargeList(Int64)'), 3);
3321+
3322+
query ??
3323+
select column1, array_repeat(column1, column2) from array_repeat_large_list_null_table;
3324+
----
3325+
[1, 2, 3] [[1, 2, 3], [1, 2, 3]]
3326+
[4, 5] NULL
3327+
NULL [NULL, NULL, NULL]
3328+
3329+
statement ok
3330+
drop table array_repeat_large_list_null_table
3331+
3332+
# array_repeat edge cases: LargeList nested with NULL count
3333+
statement ok
3334+
create table array_repeat_large_nested_null_table
3335+
as values
3336+
(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2),
3337+
(arrow_cast([[5, 6], [7, 8]], 'LargeList(List(Int64))'), null),
3338+
(arrow_cast([[null, null]], 'LargeList(List(Int64))'), 1),
3339+
(null, 3);
3340+
3341+
query ??
3342+
select column1, array_repeat(column1, column2) from array_repeat_large_nested_null_table;
3343+
----
3344+
[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
3345+
[[5, 6], [7, 8]] NULL
3346+
[[NULL, NULL]] [[[NULL, NULL]]]
3347+
NULL [NULL, NULL, NULL]
3348+
3349+
statement ok
3350+
drop table array_repeat_large_nested_null_table
3351+
32773352
## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`)
32783353

32793354
# test with empty array

0 commit comments

Comments
 (0)