Skip to content

Commit 3e481f8

Browse files
committed
Refactor string_agg handling and rename slot appender
Consolidate string-like array routing through a single StringInputArray abstraction to improve maintainability. Rename the slot appender to append_group_value for better readability of the lazy-assembly path.
1 parent 1cc0397 commit 3e481f8

1 file changed

Lines changed: 85 additions & 46 deletions

File tree

datafusion/functions-aggregate/src/string_agg.rs

Lines changed: 85 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ use std::sync::Arc;
2424
use crate::array_agg::ArrayAgg;
2525

2626
use arrow::array::{
27-
Array, ArrayAccessor, ArrayRef, AsArray, BooleanArray, LargeStringArray,
27+
Array, ArrayAccessor, ArrayRef, AsArray, BooleanArray, GenericStringArray,
28+
LargeStringArray, StringArrayType, StringViewArray,
2829
};
2930
use arrow::datatypes::{DataType, Field, FieldRef};
3031
use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
@@ -333,6 +334,73 @@ struct StringAggGroupsAccumulator {
333334
num_groups: usize,
334335
}
335336

337+
enum StringInputArray<'a> {
338+
Utf8(&'a GenericStringArray<i32>),
339+
LargeUtf8(&'a GenericStringArray<i64>),
340+
Utf8View(&'a StringViewArray),
341+
}
342+
343+
impl<'a> StringInputArray<'a> {
344+
fn try_new(array: &'a ArrayRef) -> Result<Self> {
345+
match array.data_type() {
346+
DataType::Utf8 => Ok(Self::Utf8(array.as_string::<i32>())),
347+
DataType::LargeUtf8 => Ok(Self::LargeUtf8(array.as_string::<i64>())),
348+
DataType::Utf8View => Ok(Self::Utf8View(array.as_string_view())),
349+
other => internal_err!("string_agg unexpected data type: {other}"),
350+
}
351+
}
352+
353+
fn append_rows(&self, group_indices: &[usize]) -> Vec<(u32, u32)> {
354+
match self {
355+
Self::Utf8(array) => {
356+
StringAggGroupsAccumulator::append_rows_typed(*array, group_indices)
357+
}
358+
Self::LargeUtf8(array) => {
359+
StringAggGroupsAccumulator::append_rows_typed(*array, group_indices)
360+
}
361+
Self::Utf8View(array) => {
362+
StringAggGroupsAccumulator::append_rows_typed(*array, group_indices)
363+
}
364+
}
365+
}
366+
367+
fn append_batch_values(
368+
&self,
369+
values: &mut [Option<String>],
370+
entries: &[(u32, u32)],
371+
delimiter: &str,
372+
emit_groups: usize,
373+
) {
374+
match self {
375+
Self::Utf8(array) => StringAggGroupsAccumulator::append_batch_values_typed(
376+
values,
377+
entries,
378+
*array,
379+
delimiter,
380+
emit_groups,
381+
),
382+
Self::LargeUtf8(array) => {
383+
StringAggGroupsAccumulator::append_batch_values_typed(
384+
values,
385+
entries,
386+
*array,
387+
delimiter,
388+
emit_groups,
389+
)
390+
}
391+
Self::Utf8View(array) => {
392+
StringAggGroupsAccumulator::append_batch_values_typed(
393+
values,
394+
entries,
395+
*array,
396+
delimiter,
397+
emit_groups,
398+
)
399+
}
400+
}
401+
}
402+
}
403+
336404
impl StringAggGroupsAccumulator {
337405
fn new(delimiter: String) -> Self {
338406
Self {
@@ -383,19 +451,21 @@ impl StringAggGroupsAccumulator {
383451
self.num_groups -= emit_groups as usize;
384452
}
385453

386-
fn append_rows<'a>(
387-
iter: impl Iterator<Item = Option<&'a str>>,
388-
group_indices: &[usize],
389-
) -> Vec<(u32, u32)> {
390-
iter.zip(group_indices.iter())
454+
fn append_rows_typed<'a, A>(array: A, group_indices: &[usize]) -> Vec<(u32, u32)>
455+
where
456+
A: StringArrayType<'a>,
457+
{
458+
array
459+
.iter()
460+
.zip(group_indices.iter())
391461
.enumerate()
392462
.filter_map(|(row_idx, (opt_value, &group_idx))| {
393463
opt_value.map(|_| (group_idx as u32, row_idx as u32))
394464
})
395465
.collect()
396466
}
397467

398-
fn append_value(
468+
fn append_group_value(
399469
values: &mut [Option<String>],
400470
group_idx: usize,
401471
value: &str,
@@ -427,7 +497,7 @@ impl StringAggGroupsAccumulator {
427497

428498
let row_idx = row_idx as usize;
429499
debug_assert!(!array.is_null(row_idx));
430-
Self::append_value(values, group_idx, array.value(row_idx), delimiter);
500+
Self::append_group_value(values, group_idx, array.value(row_idx), delimiter);
431501
}
432502
}
433503

@@ -438,31 +508,12 @@ impl StringAggGroupsAccumulator {
438508
delimiter: &str,
439509
emit_groups: usize,
440510
) -> Result<()> {
441-
match array.data_type() {
442-
DataType::Utf8 => Self::append_batch_values_typed(
443-
values,
444-
entries,
445-
array.as_string::<i32>(),
446-
delimiter,
447-
emit_groups,
448-
),
449-
DataType::LargeUtf8 => Self::append_batch_values_typed(
450-
values,
451-
entries,
452-
array.as_string::<i64>(),
453-
delimiter,
454-
emit_groups,
455-
),
456-
DataType::Utf8View => Self::append_batch_values_typed(
457-
values,
458-
entries,
459-
array.as_string_view(),
460-
delimiter,
461-
emit_groups,
462-
),
463-
other => return internal_err!("string_agg unexpected data type: {other}"),
464-
}
465-
511+
StringInputArray::try_new(array)?.append_batch_values(
512+
values,
513+
entries,
514+
delimiter,
515+
emit_groups,
516+
);
466517
Ok(())
467518
}
468519
}
@@ -477,19 +528,7 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
477528
) -> Result<()> {
478529
self.num_groups = self.num_groups.max(total_num_groups);
479530
let array = apply_filter_as_nulls(&values[0], opt_filter)?;
480-
481-
let entries = match array.data_type() {
482-
DataType::Utf8 => {
483-
Self::append_rows(array.as_string::<i32>().iter(), group_indices)
484-
}
485-
DataType::LargeUtf8 => {
486-
Self::append_rows(array.as_string::<i64>().iter(), group_indices)
487-
}
488-
DataType::Utf8View => {
489-
Self::append_rows(array.as_string_view().iter(), group_indices)
490-
}
491-
other => return internal_err!("string_agg unexpected data type: {other}"),
492-
};
531+
let entries = StringInputArray::try_new(&array)?.append_rows(group_indices);
493532

494533
if !entries.is_empty() {
495534
self.batches.push(array);

0 commit comments

Comments
 (0)