Skip to content

Commit 1cc0397

Browse files
committed
Refactor row handling and simplify logic
Remove the unnecessary &mut self from append_rows. Consolidate the repeated per-type string-append loops into a typed private helper using ArrayAccessor. Eliminate redundant runtime null checks in favor of a non-null entry invariant checked with debug_assert!. Simplify retain_after_emit into a single filter-and-renumber pass. Trim local ceremony in evaluate() and state() for clarity.
1 parent eb7b3c8 commit 1cc0397

1 file changed

Lines changed: 78 additions & 64 deletions

File tree

datafusion/functions-aggregate/src/string_agg.rs

Lines changed: 78 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ use std::sync::Arc;
2323

2424
use crate::array_agg::ArrayAgg;
2525

26-
use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, LargeStringArray};
26+
use arrow::array::{
27+
Array, ArrayAccessor, ArrayRef, AsArray, BooleanArray, LargeStringArray,
28+
};
2729
use arrow::datatypes::{DataType, Field, FieldRef};
2830
use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
2931
use datafusion_common::{
@@ -354,21 +356,24 @@ impl StringAggGroupsAccumulator {
354356
let mut retained_batches = Vec::with_capacity(self.batches.len());
355357
let mut retained_entries = Vec::with_capacity(self.batch_entries.len());
356358

357-
for (batch, mut entries) in
358-
self.batches.drain(..).zip(self.batch_entries.drain(..))
359-
{
360-
entries.retain(|(group_idx, _)| *group_idx >= emit_groups);
359+
for (batch, entries) in self.batches.drain(..).zip(self.batch_entries.drain(..)) {
360+
let entries: Vec<_> = entries
361+
.into_iter()
362+
.filter_map(|(group_idx, row_idx)| {
363+
if group_idx >= emit_groups {
364+
Some((group_idx - emit_groups, row_idx))
365+
} else {
366+
None
367+
}
368+
})
369+
.collect();
361370
if entries.is_empty() {
362371
continue;
363372
}
364373

365374
// Keep the original arrays for this prototype and only renumber
366375
// retained groups. SUB_ISSUE_04 will compact mixed batches so
367376
// partially emitted batches no longer pin their full inputs.
368-
for (group_idx, _) in &mut entries {
369-
*group_idx -= emit_groups;
370-
}
371-
372377
retained_batches.push(batch);
373378
retained_entries.push(entries);
374379
}
@@ -379,21 +384,51 @@ impl StringAggGroupsAccumulator {
379384
}
380385

381386
fn append_rows<'a>(
382-
&mut self,
383387
iter: impl Iterator<Item = Option<&'a str>>,
384388
group_indices: &[usize],
385389
) -> Vec<(u32, u32)> {
386-
let mut entries = Vec::new();
390+
iter.zip(group_indices.iter())
391+
.enumerate()
392+
.filter_map(|(row_idx, (opt_value, &group_idx))| {
393+
opt_value.map(|_| (group_idx as u32, row_idx as u32))
394+
})
395+
.collect()
396+
}
387397

388-
for (row_idx, (opt_value, &group_idx)) in
389-
iter.zip(group_indices.iter()).enumerate()
390-
{
391-
if opt_value.is_some() {
392-
entries.push((group_idx as u32, row_idx as u32));
398+
fn append_value(
399+
values: &mut [Option<String>],
400+
group_idx: usize,
401+
value: &str,
402+
delimiter: &str,
403+
) {
404+
match &mut values[group_idx] {
405+
Some(existing) => {
406+
existing.push_str(delimiter);
407+
existing.push_str(value);
393408
}
409+
slot @ None => *slot = Some(value.to_string()),
394410
}
411+
}
395412

396-
entries
413+
fn append_batch_values_typed<'a, A>(
414+
values: &mut [Option<String>],
415+
entries: &[(u32, u32)],
416+
array: A,
417+
delimiter: &str,
418+
emit_groups: usize,
419+
) where
420+
A: ArrayAccessor<Item = &'a str>,
421+
{
422+
for &(group_idx, row_idx) in entries {
423+
let group_idx = group_idx as usize;
424+
if group_idx >= emit_groups {
425+
continue;
426+
}
427+
428+
let row_idx = row_idx as usize;
429+
debug_assert!(!array.is_null(row_idx));
430+
Self::append_value(values, group_idx, array.value(row_idx), delimiter);
431+
}
397432
}
398433

399434
fn append_batch_values(
@@ -403,48 +438,28 @@ impl StringAggGroupsAccumulator {
403438
delimiter: &str,
404439
emit_groups: usize,
405440
) -> Result<()> {
406-
let append_value =
407-
|values: &mut [Option<String>], group_idx: usize, value: &str| {
408-
match &mut values[group_idx] {
409-
Some(existing) => {
410-
existing.push_str(delimiter);
411-
existing.push_str(value);
412-
}
413-
slot @ None => *slot = Some(value.to_string()),
414-
}
415-
};
416-
417441
match array.data_type() {
418-
DataType::Utf8 => {
419-
let array = array.as_string::<i32>();
420-
for &(group_idx, row_idx) in entries {
421-
let group_idx = group_idx as usize;
422-
if group_idx >= emit_groups || array.is_null(row_idx as usize) {
423-
continue;
424-
}
425-
append_value(values, group_idx, array.value(row_idx as usize));
426-
}
427-
}
428-
DataType::LargeUtf8 => {
429-
let array = array.as_string::<i64>();
430-
for &(group_idx, row_idx) in entries {
431-
let group_idx = group_idx as usize;
432-
if group_idx >= emit_groups || array.is_null(row_idx as usize) {
433-
continue;
434-
}
435-
append_value(values, group_idx, array.value(row_idx as usize));
436-
}
437-
}
438-
DataType::Utf8View => {
439-
let array = array.as_string_view();
440-
for &(group_idx, row_idx) in entries {
441-
let group_idx = group_idx as usize;
442-
if group_idx >= emit_groups || array.is_null(row_idx as usize) {
443-
continue;
444-
}
445-
append_value(values, group_idx, array.value(row_idx as usize));
446-
}
447-
}
442+
DataType::Utf8 => Self::append_batch_values_typed(
443+
values,
444+
entries,
445+
array.as_string::<i32>(),
446+
delimiter,
447+
emit_groups,
448+
),
449+
DataType::LargeUtf8 => Self::append_batch_values_typed(
450+
values,
451+
entries,
452+
array.as_string::<i64>(),
453+
delimiter,
454+
emit_groups,
455+
),
456+
DataType::Utf8View => Self::append_batch_values_typed(
457+
values,
458+
entries,
459+
array.as_string_view(),
460+
delimiter,
461+
emit_groups,
462+
),
448463
other => return internal_err!("string_agg unexpected data type: {other}"),
449464
}
450465

@@ -465,13 +480,13 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
465480

466481
let entries = match array.data_type() {
467482
DataType::Utf8 => {
468-
self.append_rows(array.as_string::<i32>().iter(), group_indices)
483+
Self::append_rows(array.as_string::<i32>().iter(), group_indices)
469484
}
470485
DataType::LargeUtf8 => {
471-
self.append_rows(array.as_string::<i64>().iter(), group_indices)
486+
Self::append_rows(array.as_string::<i64>().iter(), group_indices)
472487
}
473488
DataType::Utf8View => {
474-
self.append_rows(array.as_string_view().iter(), group_indices)
489+
Self::append_rows(array.as_string_view().iter(), group_indices)
475490
}
476491
other => return internal_err!("string_agg unexpected data type: {other}"),
477492
};
@@ -507,12 +522,11 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
507522
EmitTo::First(_) => self.retain_after_emit(emit_groups),
508523
}
509524

510-
let result: ArrayRef = Arc::new(LargeStringArray::from(to_emit));
511-
Ok(result)
525+
Ok(Arc::new(LargeStringArray::from(to_emit)))
512526
}
513527

514528
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
515-
self.evaluate(emit_to).map(|arr| vec![arr])
529+
Ok(vec![self.evaluate(emit_to)?])
516530
}
517531

518532
fn merge_batch(

0 commit comments

Comments
 (0)