Skip to content

Commit fc5345f

Browse files
committed
Optimize string_agg for hybrid eager and deferred modes
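Rework the `string_agg` groups accumulator into a hybrid of eager and deferred modes. By default each input batch is appended eagerly into materialized per-group strings; a batch is instead handled by deferred `(group_idx, row_idx)` row tracking when the total number of groups is at least 32 and the first non-null value in the batch is at least 32 bytes long. The aim is to improve performance for lightweight workloads while keeping the efficiency of the deferred path for many-group, long-value inputs. Adds a mixed-mode regression test in which a single accumulator receives both an eager batch and a deferred batch, and checks the results of `EmitTo::First` followed by `EmitTo::All`.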
1 parent 8be4df1 commit fc5345f

1 file changed

Lines changed: 154 additions & 20 deletions

File tree

datafusion/functions-aggregate/src/string_agg.rs

Lines changed: 154 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! [`StringAgg`] accumulator for the `string_agg` function
1919
2020
use std::hash::Hash;
21-
use std::mem::size_of_val;
21+
use std::mem::{size_of, size_of_val};
2222
use std::sync::Arc;
2323

2424
use crate::array_agg::ArrayAgg;
@@ -326,6 +326,10 @@ fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
326326
struct StringAggGroupsAccumulator {
327327
/// The delimiter placed between concatenated values.
328328
delimiter: String,
329+
/// Materialized string state for groups that use the eager fast path.
330+
values: Vec<Option<String>>,
331+
/// Running total of bytes stored in `values`.
332+
total_data_bytes: usize,
329333
/// Source arrays retained from input batches or merged state batches.
330334
batches: Vec<ArrayRef>,
331335
/// Per-batch `(group_idx, row_idx)` pairs for non-null rows.
@@ -341,6 +345,14 @@ enum StringInputArray<'a> {
341345
}
342346

343347
impl<'a> StringInputArray<'a> {
348+
fn sample_non_null_len(&self) -> Option<usize> {
349+
match self {
350+
Self::Utf8(array) => array.iter().flatten().next().map(str::len),
351+
Self::LargeUtf8(array) => array.iter().flatten().next().map(str::len),
352+
Self::Utf8View(array) => array.iter().flatten().next().map(str::len),
353+
}
354+
}
355+
344356
fn try_new(array: &'a ArrayRef) -> Result<Self> {
345357
match array.data_type() {
346358
DataType::Utf8 => Ok(Self::Utf8(array.as_string::<i32>())),
@@ -364,6 +376,34 @@ impl<'a> StringInputArray<'a> {
364376
}
365377
}
366378

379+
fn append_materialized(
380+
&self,
381+
values: &mut [Option<String>],
382+
group_indices: &[usize],
383+
delimiter: &str,
384+
) -> usize {
385+
match self {
386+
Self::Utf8(array) => StringAggGroupsAccumulator::append_batch_typed(
387+
values,
388+
array.iter(),
389+
group_indices,
390+
delimiter,
391+
),
392+
Self::LargeUtf8(array) => StringAggGroupsAccumulator::append_batch_typed(
393+
values,
394+
array.iter(),
395+
group_indices,
396+
delimiter,
397+
),
398+
Self::Utf8View(array) => StringAggGroupsAccumulator::append_batch_typed(
399+
values,
400+
array.iter(),
401+
group_indices,
402+
delimiter,
403+
),
404+
}
405+
}
406+
367407
fn append_batch_values(
368408
&self,
369409
values: &mut [Option<String>],
@@ -402,9 +442,14 @@ impl<'a> StringInputArray<'a> {
402442
}
403443

404444
impl StringAggGroupsAccumulator {
445+
const DEFER_GROUP_THRESHOLD: usize = 32;
446+
const DEFER_PAYLOAD_LEN_THRESHOLD: usize = 32;
447+
405448
fn new(delimiter: String) -> Self {
406449
Self {
407450
delimiter,
451+
values: Vec::new(),
452+
total_data_bytes: 0,
408453
batches: Vec::new(),
409454
batch_entries: Vec::new(),
410455
num_groups: 0,
@@ -414,6 +459,8 @@ impl StringAggGroupsAccumulator {
414459
fn clear_state(&mut self) {
415460
// `size()` measures Vec capacity rather than len, so allocate new
416461
// buffers instead of using `clear()`.
462+
self.values = Vec::new();
463+
self.total_data_bytes = 0;
417464
self.batches = Vec::new();
418465
self.batch_entries = Vec::new();
419466
self.num_groups = 0;
@@ -448,7 +495,6 @@ impl StringAggGroupsAccumulator {
448495

449496
self.batches = retained_batches;
450497
self.batch_entries = retained_entries;
451-
self.num_groups -= emit_groups as usize;
452498
}
453499

454500
fn append_rows_typed<'a, A>(array: &A, group_indices: &[usize]) -> Vec<(u32, u32)>
@@ -470,16 +516,40 @@ impl StringAggGroupsAccumulator {
470516
group_idx: usize,
471517
value: &str,
472518
delimiter: &str,
473-
) {
519+
) -> usize {
474520
match &mut values[group_idx] {
475521
Some(existing) => {
522+
let added = delimiter.len() + value.len();
523+
existing.reserve(added);
476524
existing.push_str(delimiter);
477525
existing.push_str(value);
526+
added
527+
}
528+
slot @ None => {
529+
*slot = Some(value.to_string());
530+
value.len()
478531
}
479-
slot @ None => *slot = Some(value.to_string()),
480532
}
481533
}
482534

535+
fn append_batch_typed<'a, I>(
536+
values: &mut [Option<String>],
537+
iter: I,
538+
group_indices: &[usize],
539+
delimiter: &str,
540+
) -> usize
541+
where
542+
I: Iterator<Item = Option<&'a str>>,
543+
{
544+
iter.zip(group_indices.iter())
545+
.filter_map(|(opt_value, &group_idx)| {
546+
opt_value.map(|value| {
547+
Self::append_group_value(values, group_idx, value, delimiter)
548+
})
549+
})
550+
.sum()
551+
}
552+
483553
fn append_batch_values_typed<'a, A>(
484554
values: &mut [Option<String>],
485555
entries: &[(u32, u32)],
@@ -497,7 +567,12 @@ impl StringAggGroupsAccumulator {
497567

498568
let row_idx = row_idx as usize;
499569
debug_assert!(!array.is_null(row_idx));
500-
Self::append_group_value(values, group_idx, array.value(row_idx), delimiter);
570+
let _ = Self::append_group_value(
571+
values,
572+
group_idx,
573+
array.value(row_idx),
574+
delimiter,
575+
);
501576
}
502577
}
503578

@@ -516,6 +591,17 @@ impl StringAggGroupsAccumulator {
516591
);
517592
Ok(())
518593
}
594+
595+
fn should_defer(
596+
&self,
597+
input: &StringInputArray<'_>,
598+
total_num_groups: usize,
599+
) -> bool {
600+
total_num_groups >= Self::DEFER_GROUP_THRESHOLD
601+
&& input
602+
.sample_non_null_len()
603+
.is_some_and(|len| len >= Self::DEFER_PAYLOAD_LEN_THRESHOLD)
604+
}
519605
}
520606

521607
impl GroupsAccumulator for StringAggGroupsAccumulator {
@@ -527,24 +613,35 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
527613
total_num_groups: usize,
528614
) -> Result<()> {
529615
self.num_groups = self.num_groups.max(total_num_groups);
616+
self.values.resize(total_num_groups, None);
530617
let array = apply_filter_as_nulls(&values[0], opt_filter)?;
531-
let entries = StringInputArray::try_new(&array)?.append_rows(group_indices);
618+
let input = StringInputArray::try_new(&array)?;
532619

533-
if !entries.is_empty() {
534-
self.batches.push(array);
535-
self.batch_entries.push(entries);
620+
if self.should_defer(&input, total_num_groups) {
621+
let entries = input.append_rows(group_indices);
622+
if !entries.is_empty() {
623+
self.batches.push(array);
624+
self.batch_entries.push(entries);
625+
}
626+
} else {
627+
self.total_data_bytes += input.append_materialized(
628+
&mut self.values,
629+
group_indices,
630+
&self.delimiter,
631+
);
536632
}
537633

538634
Ok(())
539635
}
540636

541637
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
542-
let emit_groups = match emit_to {
543-
EmitTo::All => self.num_groups,
544-
EmitTo::First(n) => n,
545-
};
546-
547-
let mut to_emit = vec![None; emit_groups];
638+
let mut to_emit = emit_to.take_needed(&mut self.values);
639+
let emit_groups = to_emit.len();
640+
let emitted_bytes: usize = to_emit
641+
.iter()
642+
.filter_map(|opt| opt.as_ref().map(|s| s.len()))
643+
.sum();
644+
self.total_data_bytes -= emitted_bytes;
548645

549646
for (batch, entries) in self.batches.iter().zip(&self.batch_entries) {
550647
Self::append_batch_values(
@@ -558,7 +655,10 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
558655

559656
match emit_to {
560657
EmitTo::All => self.clear_state(),
561-
EmitTo::First(_) => self.retain_after_emit(emit_groups),
658+
EmitTo::First(_) => {
659+
self.retain_after_emit(emit_groups);
660+
self.num_groups = self.values.len();
661+
}
562662
}
563663

564664
Ok(Arc::new(LargeStringArray::from(to_emit)))
@@ -598,10 +698,13 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
598698
}
599699

600700
fn size(&self) -> usize {
601-
self.batches
602-
.iter()
603-
.map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default())
604-
.sum::<usize>()
701+
self.total_data_bytes
702+
+ self.values.capacity() * size_of::<Option<String>>()
703+
+ self
704+
.batches
705+
.iter()
706+
.map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default())
707+
.sum::<usize>()
605708
+ self.batches.capacity() * size_of::<ArrayRef>()
606709
+ self
607710
.batch_entries
@@ -1120,4 +1223,35 @@ mod tests {
11201223
);
11211224
Ok(())
11221225
}
1226+
1227+
#[test]
1228+
fn groups_mixed_eager_and_deferred_batches() -> Result<()> {
1229+
let mut acc = make_groups_acc(",");
1230+
1231+
let eager_values: ArrayRef =
1232+
Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d"]));
1233+
acc.update_batch(&[eager_values], &[0, 1, 0, 1], None, 40)?;
1234+
1235+
let deferred_values: ArrayRef = Arc::new(LargeStringArray::from(vec![
1236+
"large0_abcdefghijklmnopqrstuvwxyzabcdef",
1237+
"large1_bcdefghijklmnopqrstuvwxyzabcdefg",
1238+
"large2_cdefghijklmnopqrstuvwxyzabcdefgh",
1239+
]));
1240+
acc.update_batch(&[deferred_values], &[0, 1, 39], None, 40)?;
1241+
1242+
let result = evaluate_groups(&mut acc, EmitTo::First(2));
1243+
assert_eq!(
1244+
result,
1245+
vec![
1246+
Some("a,c,large0_abcdefghijklmnopqrstuvwxyzabcdef".to_string()),
1247+
Some("b,d,large1_bcdefghijklmnopqrstuvwxyzabcdefg".to_string()),
1248+
]
1249+
);
1250+
1251+
let remaining = evaluate_groups(&mut acc, EmitTo::All);
1252+
let mut expected = vec![None; 38];
1253+
expected[37] = Some("large2_cdefghijklmnopqrstuvwxyzabcdefgh".to_string());
1254+
assert_eq!(remaining, expected);
1255+
Ok(())
1256+
}
11231257
}

0 commit comments

Comments
 (0)