@@ -23,7 +23,9 @@ use std::sync::Arc;
2323
2424use crate :: array_agg:: ArrayAgg ;
2525
26- use arrow:: array:: { Array , ArrayRef , AsArray , BooleanArray , LargeStringArray } ;
26+ use arrow:: array:: {
27+ Array , ArrayAccessor , ArrayRef , AsArray , BooleanArray , LargeStringArray ,
28+ } ;
2729use arrow:: datatypes:: { DataType , Field , FieldRef } ;
2830use datafusion_common:: cast:: { as_generic_string_array, as_string_view_array} ;
2931use datafusion_common:: {
@@ -354,21 +356,24 @@ impl StringAggGroupsAccumulator {
354356 let mut retained_batches = Vec :: with_capacity ( self . batches . len ( ) ) ;
355357 let mut retained_entries = Vec :: with_capacity ( self . batch_entries . len ( ) ) ;
356358
357- for ( batch, mut entries) in
358- self . batches . drain ( ..) . zip ( self . batch_entries . drain ( ..) )
359- {
360- entries. retain ( |( group_idx, _) | * group_idx >= emit_groups) ;
359+ for ( batch, entries) in self . batches . drain ( ..) . zip ( self . batch_entries . drain ( ..) ) {
360+ let entries: Vec < _ > = entries
361+ . into_iter ( )
362+ . filter_map ( |( group_idx, row_idx) | {
363+ if group_idx >= emit_groups {
364+ Some ( ( group_idx - emit_groups, row_idx) )
365+ } else {
366+ None
367+ }
368+ } )
369+ . collect ( ) ;
361370 if entries. is_empty ( ) {
362371 continue ;
363372 }
364373
365374 // Keep the original arrays for this prototype and only renumber
366375 // retained groups. SUB_ISSUE_04 will compact mixed batches so
367376 // partially emitted batches no longer pin their full inputs.
368- for ( group_idx, _) in & mut entries {
369- * group_idx -= emit_groups;
370- }
371-
372377 retained_batches. push ( batch) ;
373378 retained_entries. push ( entries) ;
374379 }
@@ -379,21 +384,51 @@ impl StringAggGroupsAccumulator {
379384 }
380385
381386 fn append_rows < ' a > (
382- & mut self ,
383387 iter : impl Iterator < Item = Option < & ' a str > > ,
384388 group_indices : & [ usize ] ,
385389 ) -> Vec < ( u32 , u32 ) > {
386- let mut entries = Vec :: new ( ) ;
390+ iter. zip ( group_indices. iter ( ) )
391+ . enumerate ( )
392+ . filter_map ( |( row_idx, ( opt_value, & group_idx) ) | {
393+ opt_value. map ( |_| ( group_idx as u32 , row_idx as u32 ) )
394+ } )
395+ . collect ( )
396+ }
387397
388- for ( row_idx, ( opt_value, & group_idx) ) in
389- iter. zip ( group_indices. iter ( ) ) . enumerate ( )
390- {
391- if opt_value. is_some ( ) {
392- entries. push ( ( group_idx as u32 , row_idx as u32 ) ) ;
398+ fn append_value (
399+ values : & mut [ Option < String > ] ,
400+ group_idx : usize ,
401+ value : & str ,
402+ delimiter : & str ,
403+ ) {
404+ match & mut values[ group_idx] {
405+ Some ( existing) => {
406+ existing. push_str ( delimiter) ;
407+ existing. push_str ( value) ;
393408 }
409+ slot @ None => * slot = Some ( value. to_string ( ) ) ,
394410 }
411+ }
395412
396- entries
413+ fn append_batch_values_typed < ' a , A > (
414+ values : & mut [ Option < String > ] ,
415+ entries : & [ ( u32 , u32 ) ] ,
416+ array : A ,
417+ delimiter : & str ,
418+ emit_groups : usize ,
419+ ) where
420+ A : ArrayAccessor < Item = & ' a str > ,
421+ {
422+ for & ( group_idx, row_idx) in entries {
423+ let group_idx = group_idx as usize ;
424+ if group_idx >= emit_groups {
425+ continue ;
426+ }
427+
428+ let row_idx = row_idx as usize ;
429+ debug_assert ! ( !array. is_null( row_idx) ) ;
430+ Self :: append_value ( values, group_idx, array. value ( row_idx) , delimiter) ;
431+ }
397432 }
398433
399434 fn append_batch_values (
@@ -403,48 +438,28 @@ impl StringAggGroupsAccumulator {
403438 delimiter : & str ,
404439 emit_groups : usize ,
405440 ) -> Result < ( ) > {
406- let append_value =
407- |values : & mut [ Option < String > ] , group_idx : usize , value : & str | {
408- match & mut values[ group_idx] {
409- Some ( existing) => {
410- existing. push_str ( delimiter) ;
411- existing. push_str ( value) ;
412- }
413- slot @ None => * slot = Some ( value. to_string ( ) ) ,
414- }
415- } ;
416-
417441 match array. data_type ( ) {
418- DataType :: Utf8 => {
419- let array = array. as_string :: < i32 > ( ) ;
420- for & ( group_idx, row_idx) in entries {
421- let group_idx = group_idx as usize ;
422- if group_idx >= emit_groups || array. is_null ( row_idx as usize ) {
423- continue ;
424- }
425- append_value ( values, group_idx, array. value ( row_idx as usize ) ) ;
426- }
427- }
428- DataType :: LargeUtf8 => {
429- let array = array. as_string :: < i64 > ( ) ;
430- for & ( group_idx, row_idx) in entries {
431- let group_idx = group_idx as usize ;
432- if group_idx >= emit_groups || array. is_null ( row_idx as usize ) {
433- continue ;
434- }
435- append_value ( values, group_idx, array. value ( row_idx as usize ) ) ;
436- }
437- }
438- DataType :: Utf8View => {
439- let array = array. as_string_view ( ) ;
440- for & ( group_idx, row_idx) in entries {
441- let group_idx = group_idx as usize ;
442- if group_idx >= emit_groups || array. is_null ( row_idx as usize ) {
443- continue ;
444- }
445- append_value ( values, group_idx, array. value ( row_idx as usize ) ) ;
446- }
447- }
442+ DataType :: Utf8 => Self :: append_batch_values_typed (
443+ values,
444+ entries,
445+ array. as_string :: < i32 > ( ) ,
446+ delimiter,
447+ emit_groups,
448+ ) ,
449+ DataType :: LargeUtf8 => Self :: append_batch_values_typed (
450+ values,
451+ entries,
452+ array. as_string :: < i64 > ( ) ,
453+ delimiter,
454+ emit_groups,
455+ ) ,
456+ DataType :: Utf8View => Self :: append_batch_values_typed (
457+ values,
458+ entries,
459+ array. as_string_view ( ) ,
460+ delimiter,
461+ emit_groups,
462+ ) ,
448463 other => return internal_err ! ( "string_agg unexpected data type: {other}" ) ,
449464 }
450465
@@ -465,13 +480,13 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
465480
466481 let entries = match array. data_type ( ) {
467482 DataType :: Utf8 => {
468- self . append_rows ( array. as_string :: < i32 > ( ) . iter ( ) , group_indices)
483+ Self :: append_rows ( array. as_string :: < i32 > ( ) . iter ( ) , group_indices)
469484 }
470485 DataType :: LargeUtf8 => {
471- self . append_rows ( array. as_string :: < i64 > ( ) . iter ( ) , group_indices)
486+ Self :: append_rows ( array. as_string :: < i64 > ( ) . iter ( ) , group_indices)
472487 }
473488 DataType :: Utf8View => {
474- self . append_rows ( array. as_string_view ( ) . iter ( ) , group_indices)
489+ Self :: append_rows ( array. as_string_view ( ) . iter ( ) , group_indices)
475490 }
476491 other => return internal_err ! ( "string_agg unexpected data type: {other}" ) ,
477492 } ;
@@ -507,12 +522,11 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
507522 EmitTo :: First ( _) => self . retain_after_emit ( emit_groups) ,
508523 }
509524
510- let result: ArrayRef = Arc :: new ( LargeStringArray :: from ( to_emit) ) ;
511- Ok ( result)
525+ Ok ( Arc :: new ( LargeStringArray :: from ( to_emit) ) )
512526 }
513527
514528 fn state ( & mut self , emit_to : EmitTo ) -> Result < Vec < ArrayRef > > {
515- self . evaluate ( emit_to) . map ( |arr| vec ! [ arr ] )
529+ Ok ( vec ! [ self . evaluate( emit_to) ? ] )
516530 }
517531
518532 fn merge_batch (
0 commit comments