@@ -24,7 +24,8 @@ use std::sync::Arc;
2424use crate :: array_agg:: ArrayAgg ;
2525
2626use arrow:: array:: {
27- Array , ArrayAccessor , ArrayRef , AsArray , BooleanArray , LargeStringArray ,
27+ Array , ArrayAccessor , ArrayRef , AsArray , BooleanArray , GenericStringArray ,
28+ LargeStringArray , StringArrayType , StringViewArray ,
2829} ;
2930use arrow:: datatypes:: { DataType , Field , FieldRef } ;
3031use datafusion_common:: cast:: { as_generic_string_array, as_string_view_array} ;
@@ -333,6 +334,73 @@ struct StringAggGroupsAccumulator {
333334 num_groups : usize ,
334335}
335336
337+ enum StringInputArray < ' a > {
338+ Utf8 ( & ' a GenericStringArray < i32 > ) ,
339+ LargeUtf8 ( & ' a GenericStringArray < i64 > ) ,
340+ Utf8View ( & ' a StringViewArray ) ,
341+ }
342+
343+ impl < ' a > StringInputArray < ' a > {
344+ fn try_new ( array : & ' a ArrayRef ) -> Result < Self > {
345+ match array. data_type ( ) {
346+ DataType :: Utf8 => Ok ( Self :: Utf8 ( array. as_string :: < i32 > ( ) ) ) ,
347+ DataType :: LargeUtf8 => Ok ( Self :: LargeUtf8 ( array. as_string :: < i64 > ( ) ) ) ,
348+ DataType :: Utf8View => Ok ( Self :: Utf8View ( array. as_string_view ( ) ) ) ,
349+ other => internal_err ! ( "string_agg unexpected data type: {other}" ) ,
350+ }
351+ }
352+
353+ fn append_rows ( & self , group_indices : & [ usize ] ) -> Vec < ( u32 , u32 ) > {
354+ match self {
355+ Self :: Utf8 ( array) => {
356+ StringAggGroupsAccumulator :: append_rows_typed ( * array, group_indices)
357+ }
358+ Self :: LargeUtf8 ( array) => {
359+ StringAggGroupsAccumulator :: append_rows_typed ( * array, group_indices)
360+ }
361+ Self :: Utf8View ( array) => {
362+ StringAggGroupsAccumulator :: append_rows_typed ( * array, group_indices)
363+ }
364+ }
365+ }
366+
367+ fn append_batch_values (
368+ & self ,
369+ values : & mut [ Option < String > ] ,
370+ entries : & [ ( u32 , u32 ) ] ,
371+ delimiter : & str ,
372+ emit_groups : usize ,
373+ ) {
374+ match self {
375+ Self :: Utf8 ( array) => StringAggGroupsAccumulator :: append_batch_values_typed (
376+ values,
377+ entries,
378+ * array,
379+ delimiter,
380+ emit_groups,
381+ ) ,
382+ Self :: LargeUtf8 ( array) => {
383+ StringAggGroupsAccumulator :: append_batch_values_typed (
384+ values,
385+ entries,
386+ * array,
387+ delimiter,
388+ emit_groups,
389+ )
390+ }
391+ Self :: Utf8View ( array) => {
392+ StringAggGroupsAccumulator :: append_batch_values_typed (
393+ values,
394+ entries,
395+ * array,
396+ delimiter,
397+ emit_groups,
398+ )
399+ }
400+ }
401+ }
402+ }
403+
336404impl StringAggGroupsAccumulator {
337405 fn new ( delimiter : String ) -> Self {
338406 Self {
@@ -383,19 +451,21 @@ impl StringAggGroupsAccumulator {
383451 self . num_groups -= emit_groups as usize ;
384452 }
385453
386- fn append_rows < ' a > (
387- iter : impl Iterator < Item = Option < & ' a str > > ,
388- group_indices : & [ usize ] ,
389- ) -> Vec < ( u32 , u32 ) > {
390- iter. zip ( group_indices. iter ( ) )
454+ fn append_rows_typed < ' a , A > ( array : A , group_indices : & [ usize ] ) -> Vec < ( u32 , u32 ) >
455+ where
456+ A : StringArrayType < ' a > ,
457+ {
458+ array
459+ . iter ( )
460+ . zip ( group_indices. iter ( ) )
391461 . enumerate ( )
392462 . filter_map ( |( row_idx, ( opt_value, & group_idx) ) | {
393463 opt_value. map ( |_| ( group_idx as u32 , row_idx as u32 ) )
394464 } )
395465 . collect ( )
396466 }
397467
398- fn append_value (
468+ fn append_group_value (
399469 values : & mut [ Option < String > ] ,
400470 group_idx : usize ,
401471 value : & str ,
@@ -427,7 +497,7 @@ impl StringAggGroupsAccumulator {
427497
428498 let row_idx = row_idx as usize ;
429499 debug_assert ! ( !array. is_null( row_idx) ) ;
430- Self :: append_value ( values, group_idx, array. value ( row_idx) , delimiter) ;
500+ Self :: append_group_value ( values, group_idx, array. value ( row_idx) , delimiter) ;
431501 }
432502 }
433503
@@ -438,31 +508,12 @@ impl StringAggGroupsAccumulator {
438508 delimiter : & str ,
439509 emit_groups : usize ,
440510 ) -> Result < ( ) > {
441- match array. data_type ( ) {
442- DataType :: Utf8 => Self :: append_batch_values_typed (
443- values,
444- entries,
445- array. as_string :: < i32 > ( ) ,
446- delimiter,
447- emit_groups,
448- ) ,
449- DataType :: LargeUtf8 => Self :: append_batch_values_typed (
450- values,
451- entries,
452- array. as_string :: < i64 > ( ) ,
453- delimiter,
454- emit_groups,
455- ) ,
456- DataType :: Utf8View => Self :: append_batch_values_typed (
457- values,
458- entries,
459- array. as_string_view ( ) ,
460- delimiter,
461- emit_groups,
462- ) ,
463- other => return internal_err ! ( "string_agg unexpected data type: {other}" ) ,
464- }
465-
511+ StringInputArray :: try_new ( array) ?. append_batch_values (
512+ values,
513+ entries,
514+ delimiter,
515+ emit_groups,
516+ ) ;
466517 Ok ( ( ) )
467518 }
468519}
@@ -477,19 +528,7 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
477528 ) -> Result < ( ) > {
478529 self . num_groups = self . num_groups . max ( total_num_groups) ;
479530 let array = apply_filter_as_nulls ( & values[ 0 ] , opt_filter) ?;
480-
481- let entries = match array. data_type ( ) {
482- DataType :: Utf8 => {
483- Self :: append_rows ( array. as_string :: < i32 > ( ) . iter ( ) , group_indices)
484- }
485- DataType :: LargeUtf8 => {
486- Self :: append_rows ( array. as_string :: < i64 > ( ) . iter ( ) , group_indices)
487- }
488- DataType :: Utf8View => {
489- Self :: append_rows ( array. as_string_view ( ) . iter ( ) , group_indices)
490- }
491- other => return internal_err ! ( "string_agg unexpected data type: {other}" ) ,
492- } ;
531+ let entries = StringInputArray :: try_new ( & array) ?. append_rows ( group_indices) ;
493532
494533 if !entries. is_empty ( ) {
495534 self . batches . push ( array) ;
0 commit comments