Skip to content

Commit 0ab78e7

Browse files
authored
fix(spark): array_repeat returns repeated NULLs instead of NULL when element is NULL (apache#21558)
## Which issue does this PR close? - Closes apache#21512. ## Rationale for this change NULL behavior does not align at the moment and with this PR it is. You can see issue ## What changes are included in this PR? only check count argument for null for returning null ## Are these changes tested? yes slt tests ## Are there any user-facing changes? Now user's will be able to see NULL element is repeated as in spark
1 parent ec00112 commit 0ab78e7

2 files changed

Lines changed: 17 additions & 9 deletions

File tree

datafusion/spark/src/function/array/repeat.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::function::null_utils::{
2828
NullMaskResolution, apply_null_mask, compute_null_mask,
2929
};
3030

31-
/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL inputs: in spark if any input is NULL, the result is NULL.
31+
/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL count: in Spark if the count is NULL, the result is NULL.
3232
/// <https://spark.apache.org/docs/latest/api/sql/index.html#array_repeat>
3333
#[derive(Debug, PartialEq, Eq, Hash)]
3434
pub struct SparkArrayRepeat {
@@ -88,7 +88,7 @@ impl ScalarUDFImpl for SparkArrayRepeat {
8888
}
8989

9090
/// This is a Spark-specific wrapper around DataFusion's array_repeat that returns NULL
91-
/// if any argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs.
91+
/// if the count argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs.
9292
fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
9393
let ScalarFunctionArgs {
9494
args: arg_values,
@@ -99,15 +99,14 @@ fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
9999
} = args;
100100
let return_type = return_field.data_type().clone();
101101

102-
// Step 1: Check for NULL mask in incoming args
103-
let null_mask = compute_null_mask(&arg_values, number_rows)?;
102+
// A NULL element should be repeated into the array, not cause a NULL result.
103+
let null_mask = compute_null_mask(&arg_values[1..], number_rows)?;
104104

105-
// If any argument is null then return NULL immediately
105+
// If count is null then return NULL immediately
106106
if matches!(null_mask, NullMaskResolution::ReturnNull) {
107107
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
108108
}
109109

110-
// Step 2: Delegate to DataFusion's array_repeat
111110
let array_repeat_func = ArrayRepeat::new();
112111
let func_args = ScalarFunctionArgs {
113112
args: arg_values,
@@ -118,6 +117,5 @@ fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
118117
};
119118
let result = array_repeat_func.invoke_with_args(func_args)?;
120119

121-
// Step 3: Apply NULL mask to result
122120
apply_null_mask(result, null_mask, &return_type)
123121
}

datafusion/sqllogictest/test_files/spark/array/array_repeat.slt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,17 @@ SELECT array_repeat(['123'], 2);
5959
query ?
6060
SELECT array_repeat(NULL, 2);
6161
----
62-
NULL
62+
[NULL, NULL]
63+
64+
query ?
65+
SELECT array_repeat(NULL, 1);
66+
----
67+
[NULL]
68+
69+
query ?
70+
SELECT array_repeat(NULL, 0);
71+
----
72+
[]
6373

6474
query ?
6575
SELECT array_repeat([NULL], 2);
@@ -88,7 +98,7 @@ FROM VALUES
8898
[123, 123]
8999
[]
90100
[]
91-
NULL
101+
[NULL]
92102
NULL
93103

94104

0 commit comments

Comments
 (0)