@@ -111,7 +111,7 @@ impl ScalarUDFImpl for SparkConcatWs {
111111
112112 // Use our implementation for all cases to guarantee consistent Utf8 return type.
113113 // Core's concat_ws may return Utf8View which conflicts with our return_type.
114- spark_concat_ws_with_arrays ( & args. args )
114+ spark_concat_ws_with_arrays ( & args. args , args . number_rows )
115115 }
116116
117117 fn coerce_types ( & self , arg_types : & [ DataType ] ) -> Result < Vec < DataType > > {
@@ -147,16 +147,10 @@ impl ScalarUDFImpl for SparkConcatWs {
147147}
148148
149149/// Implementation of concat_ws that supports array arguments.
150- fn spark_concat_ws_with_arrays ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
151- // Determine number of rows
152- let num_rows = args
153- . iter ( )
154- . find_map ( |x| match x {
155- ColumnarValue :: Array ( a) => Some ( a. len ( ) ) ,
156- _ => None ,
157- } )
158- . unwrap_or ( 1 ) ;
159-
150+ fn spark_concat_ws_with_arrays (
151+ args : & [ ColumnarValue ] ,
152+ num_rows : usize ,
153+ ) -> Result < ColumnarValue > {
160154 // Convert all to arrays for uniform processing
161155 let arrays: Vec < ArrayRef > = args
162156 . iter ( )
@@ -237,10 +231,10 @@ fn collect_parts(arr: &ArrayRef, row_idx: usize, parts: &mut Vec<String>) -> Res
237231 parts. push ( str_arr. value ( row_idx) . to_string ( ) ) ;
238232 }
239233 DataType :: List ( _) => {
240- collect_parts_from_list :: < i32 > ( arr. as_list ( ) , row_idx, parts) ?;
234+ collect_parts_from_list :: < i32 > ( arr. as_list :: < i32 > ( ) , row_idx, parts) ?;
241235 }
242236 DataType :: LargeList ( _) => {
243- collect_parts_from_list :: < i64 > ( arr. as_list ( ) , row_idx, parts) ?;
237+ collect_parts_from_list :: < i64 > ( arr. as_list :: < i64 > ( ) , row_idx, parts) ?;
244238 }
245239 other => {
246240 return exec_err ! ( "concat_ws does not support data type {other:?}" ) ;
@@ -347,12 +341,15 @@ mod tests {
347341 let b: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [ "b" ] ) ) ;
348342 let c: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [ "c" ] ) ) ;
349343
350- let result = spark_concat_ws_with_arrays ( & [
351- ColumnarValue :: Array ( sep) ,
352- ColumnarValue :: Array ( a) ,
353- ColumnarValue :: Array ( b) ,
354- ColumnarValue :: Array ( c) ,
355- ] ) ?;
344+ let result = spark_concat_ws_with_arrays (
345+ & [
346+ ColumnarValue :: Array ( sep) ,
347+ ColumnarValue :: Array ( a) ,
348+ ColumnarValue :: Array ( b) ,
349+ ColumnarValue :: Array ( c) ,
350+ ] ,
351+ 1 ,
352+ ) ?;
356353
357354 match result {
358355 ColumnarValue :: Array ( arr) => {
@@ -371,12 +368,15 @@ mod tests {
371368 let b: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [ None :: <& str >] ) ) ;
372369 let c: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [ Some ( "c" ) ] ) ) ;
373370
374- let result = spark_concat_ws_with_arrays ( & [
375- ColumnarValue :: Array ( sep) ,
376- ColumnarValue :: Array ( a) ,
377- ColumnarValue :: Array ( b) ,
378- ColumnarValue :: Array ( c) ,
379- ] ) ?;
371+ let result = spark_concat_ws_with_arrays (
372+ & [
373+ ColumnarValue :: Array ( sep) ,
374+ ColumnarValue :: Array ( a) ,
375+ ColumnarValue :: Array ( b) ,
376+ ColumnarValue :: Array ( c) ,
377+ ] ,
378+ 1 ,
379+ ) ?;
380380
381381 match result {
382382 ColumnarValue :: Array ( arr) => {
@@ -393,10 +393,10 @@ mod tests {
393393 let sep: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [ None :: <& str >] ) ) ;
394394 let a: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [ Some ( "a" ) ] ) ) ;
395395
396- let result = spark_concat_ws_with_arrays ( & [
397- ColumnarValue :: Array ( sep) ,
398- ColumnarValue :: Array ( a ) ,
399- ] ) ?;
396+ let result = spark_concat_ws_with_arrays (
397+ & [ ColumnarValue :: Array ( sep) , ColumnarValue :: Array ( a ) ] ,
398+ 1 ,
399+ ) ?;
400400
401401 match result {
402402 ColumnarValue :: Array ( arr) => {
@@ -414,10 +414,10 @@ mod tests {
414414 let list = make_list_array ( vec ! [ Some ( vec![ Some ( "a" ) , Some ( "b" ) , Some ( "c" ) ] ) ] ) ;
415415 let list_ref: ArrayRef = Arc :: new ( list) ;
416416
417- let result = spark_concat_ws_with_arrays ( & [
418- ColumnarValue :: Array ( sep) ,
419- ColumnarValue :: Array ( list_ref ) ,
420- ] ) ?;
417+ let result = spark_concat_ws_with_arrays (
418+ & [ ColumnarValue :: Array ( sep) , ColumnarValue :: Array ( list_ref ) ] ,
419+ 1 ,
420+ ) ?;
421421
422422 match result {
423423 ColumnarValue :: Array ( arr) => {
@@ -435,10 +435,10 @@ mod tests {
435435 let list = make_list_array ( vec ! [ Some ( vec![ Some ( "a" ) , None , Some ( "c" ) ] ) ] ) ;
436436 let list_ref: ArrayRef = Arc :: new ( list) ;
437437
438- let result = spark_concat_ws_with_arrays ( & [
439- ColumnarValue :: Array ( sep) ,
440- ColumnarValue :: Array ( list_ref ) ,
441- ] ) ?;
438+ let result = spark_concat_ws_with_arrays (
439+ & [ ColumnarValue :: Array ( sep) , ColumnarValue :: Array ( list_ref ) ] ,
440+ 1 ,
441+ ) ?;
442442
443443 match result {
444444 ColumnarValue :: Array ( arr) => {
@@ -458,12 +458,15 @@ mod tests {
458458 let list_ref: ArrayRef = Arc :: new ( list) ;
459459 let y: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [ "y" ] ) ) ;
460460
461- let result = spark_concat_ws_with_arrays ( & [
462- ColumnarValue :: Array ( sep) ,
463- ColumnarValue :: Array ( x) ,
464- ColumnarValue :: Array ( list_ref) ,
465- ColumnarValue :: Array ( y) ,
466- ] ) ?;
461+ let result = spark_concat_ws_with_arrays (
462+ & [
463+ ColumnarValue :: Array ( sep) ,
464+ ColumnarValue :: Array ( x) ,
465+ ColumnarValue :: Array ( list_ref) ,
466+ ColumnarValue :: Array ( y) ,
467+ ] ,
468+ 1 ,
469+ ) ?;
467470
468471 match result {
469472 ColumnarValue :: Array ( arr) => {
@@ -482,11 +485,14 @@ mod tests {
482485 let b: ArrayRef =
483486 Arc :: new ( StringArray :: from ( vec ! [ Some ( "b" ) , Some ( "y" ) , Some ( "z" ) ] ) ) ;
484487
485- let result = spark_concat_ws_with_arrays ( & [
486- ColumnarValue :: Array ( sep) ,
487- ColumnarValue :: Array ( a) ,
488- ColumnarValue :: Array ( b) ,
489- ] ) ?;
488+ let result = spark_concat_ws_with_arrays (
489+ & [
490+ ColumnarValue :: Array ( sep) ,
491+ ColumnarValue :: Array ( a) ,
492+ ColumnarValue :: Array ( b) ,
493+ ] ,
494+ 3 ,
495+ ) ?;
490496
491497 match result {
492498 ColumnarValue :: Array ( arr) => {
0 commit comments