Skip to content

Commit 85b4ece

Browse files
committed
Revert "Refactor emission handling and simplify logic"
This reverts commit baa8054.
1 parent baa8054 commit 85b4ece

1 file changed

Lines changed: 75 additions & 63 deletions

File tree

datafusion/functions-aggregate/src/string_agg.rs

Lines changed: 75 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -334,6 +334,8 @@ struct StringAggGroupsAccumulator {
334334
batches: Vec<ArrayRef>,
335335
/// Per-batch `(group_idx, row_idx)` pairs for non-null rows.
336336
batch_entries: Vec<Vec<(u32, u32)>>,
337+
/// Total number of groups tracked.
338+
num_groups: usize,
337339
}
338340

339341
enum StringInputArray<'a> {
@@ -420,6 +422,7 @@ impl StringAggGroupsAccumulator {
420422
total_data_bytes: 0,
421423
batches: Vec::new(),
422424
batch_entries: Vec::new(),
425+
num_groups: 0,
423426
}
424427
}
425428

@@ -430,32 +433,38 @@ impl StringAggGroupsAccumulator {
430433
self.total_data_bytes = 0;
431434
self.batches = Vec::new();
432435
self.batch_entries = Vec::new();
436+
self.num_groups = 0;
433437
}
434438

435439
fn retain_after_emit(&mut self, emit_groups: usize) {
436440
let emit_groups = emit_groups as u32;
437-
(self.batches, self.batch_entries) = self
438-
.batches
439-
.drain(..)
440-
.zip(self.batch_entries.drain(..))
441-
.filter_map(|(batch, entries)| {
442-
let entries = entries
443-
.into_iter()
444-
.filter_map(|(group_idx, row_idx)| {
445-
if group_idx >= emit_groups {
446-
Some((group_idx - emit_groups, row_idx))
447-
} else {
448-
None
449-
}
450-
})
451-
.collect::<Vec<_>>();
452-
453-
// Keep the original arrays for this prototype and only renumber
454-
// retained groups. SUB_ISSUE_04 will compact mixed batches so
455-
// partially emitted batches no longer pin their full inputs.
456-
(!entries.is_empty()).then_some((batch, entries))
457-
})
458-
.unzip();
441+
let mut retained_batches = Vec::with_capacity(self.batches.len());
442+
let mut retained_entries = Vec::with_capacity(self.batch_entries.len());
443+
444+
for (batch, entries) in self.batches.drain(..).zip(self.batch_entries.drain(..)) {
445+
let entries: Vec<_> = entries
446+
.into_iter()
447+
.filter_map(|(group_idx, row_idx)| {
448+
if group_idx >= emit_groups {
449+
Some((group_idx - emit_groups, row_idx))
450+
} else {
451+
None
452+
}
453+
})
454+
.collect();
455+
if entries.is_empty() {
456+
continue;
457+
}
458+
459+
// Keep the original arrays for this prototype and only renumber
460+
// retained groups. SUB_ISSUE_04 will compact mixed batches so
461+
// partially emitted batches no longer pin their full inputs.
462+
retained_batches.push(batch);
463+
retained_entries.push(entries);
464+
}
465+
466+
self.batches = retained_batches;
467+
self.batch_entries = retained_entries;
459468
}
460469

461470
fn append_rows_typed<'a, A>(array: &A, group_indices: &[usize]) -> Vec<(u32, u32)>
@@ -520,12 +529,12 @@ impl StringAggGroupsAccumulator {
520529
) where
521530
A: ArrayAccessor<Item = &'a str>,
522531
{
523-
for (group_idx, row_idx) in entries
524-
.iter()
525-
.copied()
526-
.filter(|(group_idx, _)| (*group_idx as usize) < emit_groups)
527-
{
532+
for &(group_idx, row_idx) in entries {
528533
let group_idx = group_idx as usize;
534+
if group_idx >= emit_groups {
535+
continue;
536+
}
537+
529538
let row_idx = row_idx as usize;
530539
debug_assert!(!array.is_null(row_idx));
531540
let _ = Self::append_group_value(
@@ -537,42 +546,31 @@ impl StringAggGroupsAccumulator {
537546
}
538547
}
539548

540-
fn should_defer(
541-
&self,
542-
input: &StringInputArray<'_>,
543-
total_num_groups: usize,
544-
) -> bool {
545-
if total_num_groups < Self::DEFER_GROUP_THRESHOLD {
546-
return false;
547-
}
548-
549-
input
550-
.sample_non_null_len()
551-
.is_some_and(|len| len >= Self::DEFER_PAYLOAD_LEN_THRESHOLD)
552-
}
553-
554-
fn replay_deferred_batches(
555-
&self,
549+
fn append_batch_values(
556550
values: &mut [Option<String>],
551+
entries: &[(u32, u32)],
552+
array: &ArrayRef,
553+
delimiter: &str,
557554
emit_groups: usize,
558555
) -> Result<()> {
559-
for (batch, entries) in self.batches.iter().zip(&self.batch_entries) {
560-
StringInputArray::try_new(batch)?.append_batch_values(
561-
values,
562-
entries,
563-
&self.delimiter,
564-
emit_groups,
565-
);
566-
}
567-
556+
StringInputArray::try_new(array)?.append_batch_values(
557+
values,
558+
entries,
559+
delimiter,
560+
emit_groups,
561+
);
568562
Ok(())
569563
}
570564

571-
fn finish_emit(&mut self, emit_to: EmitTo, emit_groups: usize) {
572-
match emit_to {
573-
EmitTo::All => self.clear_state(),
574-
EmitTo::First(_) => self.retain_after_emit(emit_groups),
575-
}
565+
fn should_defer(
566+
&self,
567+
input: &StringInputArray<'_>,
568+
total_num_groups: usize,
569+
) -> bool {
570+
total_num_groups >= Self::DEFER_GROUP_THRESHOLD
571+
&& input
572+
.sample_non_null_len()
573+
.is_some_and(|len| len >= Self::DEFER_PAYLOAD_LEN_THRESHOLD)
576574
}
577575
}
578576

@@ -584,6 +582,7 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
584582
opt_filter: Option<&BooleanArray>,
585583
total_num_groups: usize,
586584
) -> Result<()> {
585+
self.num_groups = self.num_groups.max(total_num_groups);
587586
self.values.resize(total_num_groups, None);
588587
let array = apply_filter_as_nulls(&values[0], opt_filter)?;
589588
let input = StringInputArray::try_new(&array)?;
@@ -606,18 +605,31 @@ impl GroupsAccumulator for StringAggGroupsAccumulator {
606605
}
607606

608607
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
609-
let mut to_emit = match emit_to {
610-
EmitTo::All => std::mem::take(&mut self.values),
611-
EmitTo::First(_) => emit_to.take_needed(&mut self.values),
612-
};
608+
let mut to_emit = emit_to.take_needed(&mut self.values);
613609
let emit_groups = to_emit.len();
614610
let emitted_bytes: usize = to_emit
615611
.iter()
616612
.filter_map(|opt| opt.as_ref().map(|s| s.len()))
617613
.sum();
618614
self.total_data_bytes -= emitted_bytes;
619-
self.replay_deferred_batches(&mut to_emit, emit_groups)?;
620-
self.finish_emit(emit_to, emit_groups);
615+
616+
for (batch, entries) in self.batches.iter().zip(&self.batch_entries) {
617+
Self::append_batch_values(
618+
&mut to_emit,
619+
entries,
620+
batch,
621+
&self.delimiter,
622+
emit_groups,
623+
)?;
624+
}
625+
626+
match emit_to {
627+
EmitTo::All => self.clear_state(),
628+
EmitTo::First(_) => {
629+
self.retain_after_emit(emit_groups);
630+
self.num_groups = self.values.len();
631+
}
632+
}
621633

622634
Ok(Arc::new(LargeStringArray::from(to_emit)))
623635
}

0 commit comments

Comments
 (0)