diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index bcccea381324e..220148a080fbb 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -757,11 +757,9 @@ where /// The type of the returned sum return_data_type: DataType, - /// Count per group (use u64 to make UInt64Array) - counts: Vec, - - /// Sums per group, stored as the native type - sums: Vec, + /// Combined count and sum per group in a single Vec to halve reallocation cost. + /// Each entry stores (count, sum) for one group. + states: Vec>, /// Track nulls in the input / filters null_state: NullState, @@ -770,6 +768,14 @@ where avg_fn: F, } +/// Combined per-group state for AVG accumulator. +/// Stored in a single Vec to reduce reallocation overhead. +#[derive(Debug, Clone, Copy)] +struct AvgState { + count: u64, + sum: N, +} + impl AvgGroupsAccumulator where T: ArrowNumericType + Send, @@ -784,8 +790,7 @@ where Self { return_data_type: return_data_type.clone(), sum_data_type: sum_data_type.clone(), - counts: vec![], - sums: vec![], + states: vec![], null_state: NullState::new(), avg_fn, } @@ -808,8 +813,13 @@ where let values = values[0].as_primitive::(); // increment counts, update sums - self.counts.resize(total_num_groups, 0); - self.sums.resize(total_num_groups, T::default_value()); + self.states.resize( + total_num_groups, + AvgState { + count: 0, + sum: T::default_value(), + }, + ); self.null_state.accumulate( group_indices, values, @@ -817,10 +827,9 @@ where total_num_groups, |group_index, new_value| { // SAFETY: group_index is guaranteed to be in bounds - let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; - *sum = sum.add_wrapping(new_value); - - self.counts[group_index] += 1; + let state = unsafe { self.states.get_unchecked_mut(group_index) }; + state.sum = state.sum.add_wrapping(new_value); + state.count += 1; }, ); @@ -828,14 +837,12 @@ where } fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - let sums = emit_to.take_needed(&mut self.sums); + let states = emit_to.take_needed(&mut self.states); let nulls = self.null_state.build(emit_to); if let Some(nulls) = &nulls { - assert_eq!(nulls.len(), sums.len()); + assert_eq!(nulls.len(), states.len()); } - assert_eq!(counts.len(), sums.len()); // don't evaluate averages with null inputs to avoid errors on null values @@ -844,21 +851,20 @@ where { let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()) .with_data_type(self.return_data_type.clone()); - let iter = sums.into_iter().zip(counts).zip(nulls.iter()); + let iter = states.into_iter().zip(nulls.iter()); - for ((sum, count), is_valid) in iter { + for (state, is_valid) in iter { if is_valid { - builder.append_value((self.avg_fn)(sum, count)?) + builder.append_value((self.avg_fn)(state.sum, state.count)?) } else { builder.append_null(); } } builder.finish() } else { - let averages: Vec = sums + let averages: Vec = states .into_iter() - .zip(counts.into_iter()) - .map(|(sum, count)| (self.avg_fn)(sum, count)) + .map(|state| (self.avg_fn)(state.sum, state.count)) .collect::>>()?; PrimitiveArray::new(averages.into(), nulls) // no copy .with_data_type(self.return_data_type.clone()) @@ -871,11 +877,11 @@ where fn state(&mut self, emit_to: EmitTo) -> Result> { let nulls = self.null_state.build(emit_to); - let counts = emit_to.take_needed(&mut self.counts); - let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy - - let sums = emit_to.take_needed(&mut self.sums); - let sums = PrimitiveArray::::new(sums.into(), nulls) // zero copy + let states = emit_to.take_needed(&mut self.states); + let (counts, sums): (Vec, Vec) = + states.into_iter().map(|s| (s.count, s.sum)).unzip(); + let counts = UInt64Array::new(counts.into(), nulls.clone()); + let sums = PrimitiveArray::::new(sums.into(), nulls) .with_data_type(self.sum_data_type.clone()); Ok(vec![ @@ -895,8 +901,14 @@ where // first batch is counts, second is partial sums let partial_counts = values[0].as_primitive::(); let partial_sums = values[1].as_primitive::(); - // update counts with partial counts - self.counts.resize(total_num_groups, 0); + // single resize for combined state + self.states.resize( + total_num_groups, + AvgState { + count: 0, + sum: T::default_value(), + }, + ); self.null_state.accumulate( group_indices, partial_counts, @@ -904,13 +916,10 @@ where total_num_groups, |group_index, partial_count| { // SAFETY: group_index is guaranteed to be in bounds - let count = unsafe { self.counts.get_unchecked_mut(group_index) }; - *count += partial_count; + let state = unsafe { self.states.get_unchecked_mut(group_index) }; + state.count += partial_count; }, ); - - // update sums - self.sums.resize(total_num_groups, T::default_value()); self.null_state.accumulate( group_indices, partial_sums, @@ -918,8 +927,8 @@ where total_num_groups, |group_index, new_value: ::Native| { // SAFETY: group_index is guaranteed to be in bounds - let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; - *sum = sum.add_wrapping(new_value); + let state = unsafe { self.states.get_unchecked_mut(group_index) }; + state.sum = state.sum.add_wrapping(new_value); }, ); @@ -951,6 +960,6 @@ where } fn size(&self) -> usize { - self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() + self.states.capacity() * size_of::>() } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs index a91dd3115d879..f0a83a3a4dd12 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs @@ -157,6 +157,8 @@ impl ByteViewGroupValueBuilder { Nulls::Some }; + self.views.reserve(rows.len()); + match all_null_or_non_null { Nulls::Some => { for &row in rows { diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index 31126348b3fd4..cf02504fd709c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -203,6 +203,8 @@ impl GroupColumn Nulls::Some }; + self.group_values.reserve(rows.len()); + match (NULLABLE, all_null_or_non_null) { (true, Nulls::Some) => { for &row in rows {