Skip to content

Commit ded539c

Browse files
committed
fix revision
1 parent b4303ee commit ded539c

2 files changed

Lines changed: 62 additions & 48 deletions

File tree

datafusion/spark/src/function/string/concat_ws.rs

Lines changed: 54 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,7 @@ impl ScalarUDFImpl for SparkConcatWs {
111111

112112
// Use our implementation for all cases to guarantee consistent Utf8 return type.
113113
// Core's concat_ws may return Utf8View which conflicts with our return_type.
114-
spark_concat_ws_with_arrays(&args.args)
114+
spark_concat_ws_with_arrays(&args.args, args.number_rows)
115115
}
116116

117117
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
@@ -147,16 +147,10 @@ impl ScalarUDFImpl for SparkConcatWs {
147147
}
148148

149149
/// Implementation of concat_ws that supports array arguments.
150-
fn spark_concat_ws_with_arrays(args: &[ColumnarValue]) -> Result<ColumnarValue> {
151-
// Determine number of rows
152-
let num_rows = args
153-
.iter()
154-
.find_map(|x| match x {
155-
ColumnarValue::Array(a) => Some(a.len()),
156-
_ => None,
157-
})
158-
.unwrap_or(1);
159-
150+
fn spark_concat_ws_with_arrays(
151+
args: &[ColumnarValue],
152+
num_rows: usize,
153+
) -> Result<ColumnarValue> {
160154
// Convert all to arrays for uniform processing
161155
let arrays: Vec<ArrayRef> = args
162156
.iter()
@@ -237,10 +231,10 @@ fn collect_parts(arr: &ArrayRef, row_idx: usize, parts: &mut Vec<String>) -> Res
237231
parts.push(str_arr.value(row_idx).to_string());
238232
}
239233
DataType::List(_) => {
240-
collect_parts_from_list::<i32>(arr.as_list(), row_idx, parts)?;
234+
collect_parts_from_list::<i32>(arr.as_list::<i32>(), row_idx, parts)?;
241235
}
242236
DataType::LargeList(_) => {
243-
collect_parts_from_list::<i64>(arr.as_list(), row_idx, parts)?;
237+
collect_parts_from_list::<i64>(arr.as_list::<i64>(), row_idx, parts)?;
244238
}
245239
other => {
246240
return exec_err!("concat_ws does not support data type {other:?}");
@@ -347,12 +341,15 @@ mod tests {
347341
let b: ArrayRef = Arc::new(StringArray::from(vec!["b"]));
348342
let c: ArrayRef = Arc::new(StringArray::from(vec!["c"]));
349343

350-
let result = spark_concat_ws_with_arrays(&[
351-
ColumnarValue::Array(sep),
352-
ColumnarValue::Array(a),
353-
ColumnarValue::Array(b),
354-
ColumnarValue::Array(c),
355-
])?;
344+
let result = spark_concat_ws_with_arrays(
345+
&[
346+
ColumnarValue::Array(sep),
347+
ColumnarValue::Array(a),
348+
ColumnarValue::Array(b),
349+
ColumnarValue::Array(c),
350+
],
351+
1,
352+
)?;
356353

357354
match result {
358355
ColumnarValue::Array(arr) => {
@@ -371,12 +368,15 @@ mod tests {
371368
let b: ArrayRef = Arc::new(StringArray::from(vec![None::<&str>]));
372369
let c: ArrayRef = Arc::new(StringArray::from(vec![Some("c")]));
373370

374-
let result = spark_concat_ws_with_arrays(&[
375-
ColumnarValue::Array(sep),
376-
ColumnarValue::Array(a),
377-
ColumnarValue::Array(b),
378-
ColumnarValue::Array(c),
379-
])?;
371+
let result = spark_concat_ws_with_arrays(
372+
&[
373+
ColumnarValue::Array(sep),
374+
ColumnarValue::Array(a),
375+
ColumnarValue::Array(b),
376+
ColumnarValue::Array(c),
377+
],
378+
1,
379+
)?;
380380

381381
match result {
382382
ColumnarValue::Array(arr) => {
@@ -393,10 +393,10 @@ mod tests {
393393
let sep: ArrayRef = Arc::new(StringArray::from(vec![None::<&str>]));
394394
let a: ArrayRef = Arc::new(StringArray::from(vec![Some("a")]));
395395

396-
let result = spark_concat_ws_with_arrays(&[
397-
ColumnarValue::Array(sep),
398-
ColumnarValue::Array(a),
399-
])?;
396+
let result = spark_concat_ws_with_arrays(
397+
&[ColumnarValue::Array(sep), ColumnarValue::Array(a)],
398+
1,
399+
)?;
400400

401401
match result {
402402
ColumnarValue::Array(arr) => {
@@ -414,10 +414,10 @@ mod tests {
414414
let list = make_list_array(vec![Some(vec![Some("a"), Some("b"), Some("c")])]);
415415
let list_ref: ArrayRef = Arc::new(list);
416416

417-
let result = spark_concat_ws_with_arrays(&[
418-
ColumnarValue::Array(sep),
419-
ColumnarValue::Array(list_ref),
420-
])?;
417+
let result = spark_concat_ws_with_arrays(
418+
&[ColumnarValue::Array(sep), ColumnarValue::Array(list_ref)],
419+
1,
420+
)?;
421421

422422
match result {
423423
ColumnarValue::Array(arr) => {
@@ -435,10 +435,10 @@ mod tests {
435435
let list = make_list_array(vec![Some(vec![Some("a"), None, Some("c")])]);
436436
let list_ref: ArrayRef = Arc::new(list);
437437

438-
let result = spark_concat_ws_with_arrays(&[
439-
ColumnarValue::Array(sep),
440-
ColumnarValue::Array(list_ref),
441-
])?;
438+
let result = spark_concat_ws_with_arrays(
439+
&[ColumnarValue::Array(sep), ColumnarValue::Array(list_ref)],
440+
1,
441+
)?;
442442

443443
match result {
444444
ColumnarValue::Array(arr) => {
@@ -458,12 +458,15 @@ mod tests {
458458
let list_ref: ArrayRef = Arc::new(list);
459459
let y: ArrayRef = Arc::new(StringArray::from(vec!["y"]));
460460

461-
let result = spark_concat_ws_with_arrays(&[
462-
ColumnarValue::Array(sep),
463-
ColumnarValue::Array(x),
464-
ColumnarValue::Array(list_ref),
465-
ColumnarValue::Array(y),
466-
])?;
461+
let result = spark_concat_ws_with_arrays(
462+
&[
463+
ColumnarValue::Array(sep),
464+
ColumnarValue::Array(x),
465+
ColumnarValue::Array(list_ref),
466+
ColumnarValue::Array(y),
467+
],
468+
1,
469+
)?;
467470

468471
match result {
469472
ColumnarValue::Array(arr) => {
@@ -482,11 +485,14 @@ mod tests {
482485
let b: ArrayRef =
483486
Arc::new(StringArray::from(vec![Some("b"), Some("y"), Some("z")]));
484487

485-
let result = spark_concat_ws_with_arrays(&[
486-
ColumnarValue::Array(sep),
487-
ColumnarValue::Array(a),
488-
ColumnarValue::Array(b),
489-
])?;
488+
let result = spark_concat_ws_with_arrays(
489+
&[
490+
ColumnarValue::Array(sep),
491+
ColumnarValue::Array(a),
492+
ColumnarValue::Array(b),
493+
],
494+
3,
495+
)?;
490496

491497
match result {
492498
ColumnarValue::Array(arr) => {

datafusion/sqllogictest/test_files/spark/string/concat_ws.slt

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,3 +129,11 @@ SELECT concat_ws(',', a, b) AS result FROM VALUES ('a', 'b'), ('c', CAST(NULL AS
129129
a,b
130130
c
131131
d
132+
133+
## Scalar-only arguments over multiple rows (broadcast test)
134+
query T
135+
SELECT concat_ws(',', 'a', 'b') AS result FROM VALUES (1), (2), (3) AS t(x);
136+
----
137+
a,b
138+
a,b
139+
a,b

0 commit comments

Comments
 (0)