Skip to content

Commit ca1d39d

Browse files
azhangdalamb
and authored
perf: implement convert_to_state for SparkAvg (#21548)
## Which issue does this PR close? - Part of #17964. ## Rationale for this change SparkAvg's AvgGroupsAccumulator doesn't implement supports_convert_to_state (defaults to false), which prevents the skip-partial-aggregation optimization from kicking in for queries that use Spark's avg(). I ran into this while benchmarking a Spark Connect engine built on DataFusion. On TPC-H q17 at SF10, the partial aggregate for avg(l_quantity) grouped by l_partkey (~2M groups out of 60M rows) was not triggering skip-aggregation: | Metric | Without convert_to_state | With convert_to_state | |--------|-------------------------|-----------------------| | Partial aggregate memory | 923 MB | 40 MB | | Partial aggregate elapsed | 4.75s | 109ms | The skip-aggregation probe (#11627) detects when a partial aggregate isn't reducing cardinality and falls back to passing rows through as state directly. This needs convert_to_state so the accumulator can produce [sum, count] state arrays from raw input. The built-in Avg already has this (#11734), but it wasn't carried over when SparkAvg was migrated from Comet in #17871. ## What changes are included in this PR? Adds convert_to_state() and supports_convert_to_state() to AvgGroupsAccumulator in datafusion-spark. Follows the same approach as the built-in Avg, adapted for SparkAvg's differences: - State order is [sum, count] (vs [count, sum] in the built-in) - Count type is Int64 (vs UInt64 in the built-in) - Null handling uses NullBuffer::union directly instead of pulling in datafusion-functions-aggregate-common as a dep Also cleaned up the fully-qualified arrow::array::BooleanArray references in update_batch / merge_batch since adding BooleanArray to the import block triggered the unused_qualifications lint. ## Are these changes tested? Yes, unit tests covering basic conversion, null propagation, filter handling, and a roundtrip through merge_batch to verify the converted state produces correct results end-to-end. 
## Are there any user-facing changes? No. Queries using avg() through the Spark function registry will automatically benefit from skip-partial-aggregation on high-cardinality groupings. --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 40b209e commit ca1d39d

3 files changed

Lines changed: 161 additions & 13 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/spark/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ datafusion-execution = { workspace = true }
5555
datafusion-expr = { workspace = true }
5656
datafusion-functions = { workspace = true, features = ["crypto_expressions"] }
5757
datafusion-functions-aggregate = { workspace = true }
58+
datafusion-functions-aggregate-common = { workspace = true }
5859
datafusion-functions-nested = { workspace = true }
5960
log = { workspace = true }
6061
num-traits = { workspace = true }

datafusion/spark/src/function/aggregate/avg.rs

Lines changed: 159 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
// under the License.
1717

1818
use arrow::array::{
19-
Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray,
19+
Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, BooleanArray, Int64Array,
20+
PrimitiveArray,
2021
builder::PrimitiveBuilder,
2122
cast::AsArray,
2223
types::{Float64Type, Int64Type},
@@ -31,6 +32,9 @@ use datafusion_expr::{
3132
Accumulator, AggregateUDFImpl, Coercion, EmitTo, GroupsAccumulator, ReversedUDAF,
3233
Signature, TypeSignatureClass, Volatility,
3334
};
35+
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
36+
filtered_null_mask, set_nulls,
37+
};
3438
use std::sync::Arc;
3539

3640
/// AVG aggregate expression
@@ -248,7 +252,7 @@ where
248252
&mut self,
249253
values: &[ArrayRef],
250254
group_indices: &[usize],
251-
_opt_filter: Option<&arrow::array::BooleanArray>,
255+
_opt_filter: Option<&BooleanArray>,
252256
total_num_groups: usize,
253257
) -> Result<()> {
254258
assert_eq!(values.len(), 1, "single argument to update_batch");
@@ -285,26 +289,26 @@ where
285289
&mut self,
286290
values: &[ArrayRef],
287291
group_indices: &[usize],
288-
_opt_filter: Option<&arrow::array::BooleanArray>,
292+
_opt_filter: Option<&BooleanArray>,
289293
total_num_groups: usize,
290294
) -> Result<()> {
291295
assert_eq!(values.len(), 2, "two arguments to merge_batch");
292296
// first batch is partial sums, second is counts
293297
let partial_sums = values[0].as_primitive::<T>();
294298
let partial_counts = values[1].as_primitive::<Int64Type>();
295-
// update counts with partial counts
296-
self.counts.resize(total_num_groups, 0);
297-
let iter1 = group_indices.iter().zip(partial_counts.values().iter());
298-
for (&group_index, &partial_count) in iter1 {
299-
self.counts[group_index] += partial_count;
300-
}
301299

302-
// update sums
300+
self.counts.resize(total_num_groups, 0);
303301
self.sums.resize(total_num_groups, T::default_value());
304-
let iter2 = group_indices.iter().zip(partial_sums.values().iter());
305-
for (&group_index, &new_value) in iter2 {
302+
303+
for (idx, &group_index) in group_indices.iter().enumerate() {
304+
// Skip null state entries emitted by convert_to_state for
305+
// filtered / null input rows.
306+
if partial_counts.is_null(idx) || partial_sums.is_null(idx) {
307+
continue;
308+
}
309+
self.counts[group_index] += partial_counts.value(idx);
306310
let sum = &mut self.sums[group_index];
307-
*sum = sum.add_wrapping(new_value);
311+
*sum = sum.add_wrapping(partial_sums.value(idx));
308312
}
309313

310314
Ok(())
@@ -343,7 +347,149 @@ where
343347
])
344348
}
345349

350+
fn convert_to_state(
351+
&self,
352+
values: &[ArrayRef],
353+
opt_filter: Option<&BooleanArray>,
354+
) -> Result<Vec<ArrayRef>> {
355+
let sums = values[0]
356+
.as_primitive::<T>()
357+
.clone()
358+
.with_data_type(self.return_data_type.clone());
359+
let counts = Int64Array::from_value(1, sums.len());
360+
361+
let nulls = filtered_null_mask(opt_filter, &sums);
362+
let counts = set_nulls(counts, nulls.clone());
363+
let sums = set_nulls(sums, nulls);
364+
365+
// [sum, count] - must match state() and merge_batch()
366+
Ok(vec![
367+
Arc::new(sums) as ArrayRef,
368+
Arc::new(counts) as ArrayRef,
369+
])
370+
}
371+
372+
fn supports_convert_to_state(&self) -> bool {
373+
true
374+
}
375+
346376
fn size(&self) -> usize {
347377
self.counts.capacity() * size_of::<i64>() + self.sums.capacity() * size_of::<T>()
348378
}
349379
}
380+
381+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Float64Array;

    /// Builds a Float64 accumulator using the plain sum/count average.
    fn make_acc() -> AvgGroupsAccumulator<Float64Type, impl Fn(f64, i64) -> Result<f64>> {
        AvgGroupsAccumulator::<Float64Type, _>::new(&DataType::Float64, |sum, count| {
            Ok(sum / count as f64)
        })
    }

    #[test]
    fn supports_convert_to_state() {
        assert!(make_acc().supports_convert_to_state());
    }

    #[test]
    fn convert_to_state_basic() {
        let acc = make_acc();
        let input: Vec<ArrayRef> =
            vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))];

        let state = acc.convert_to_state(&input, None).unwrap();
        assert_eq!(state.len(), 2);

        let sums = state[0].as_primitive::<Float64Type>();
        let counts = state[1].as_primitive::<Int64Type>();

        // Each row becomes its own [sum, count] = [value, 1] entry.
        assert_eq!(sums.values().as_ref(), &[1.0, 2.0, 3.0]);
        assert_eq!(counts.values().as_ref(), &[1, 1, 1]);
        assert_eq!(sums.null_count(), 0);
        assert_eq!(counts.null_count(), 0);
    }

    #[test]
    fn convert_to_state_with_nulls() {
        let acc = make_acc();
        let input: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
            Some(1.0),
            None,
            Some(3.0),
        ]))];

        let state = acc.convert_to_state(&input, None).unwrap();
        let sums = state[0].as_primitive::<Float64Type>();
        let counts = state[1].as_primitive::<Int64Type>();

        // A null input row must be null in both state columns.
        assert!(!sums.is_null(0));
        assert!(sums.is_null(1));
        assert!(!sums.is_null(2));

        assert_eq!(counts.value(0), 1);
        assert!(counts.is_null(1));
        assert_eq!(counts.value(2), 1);
    }

    #[test]
    fn convert_to_state_with_filter() {
        let acc = make_acc();
        let input: Vec<ArrayRef> =
            vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))];
        let filter = BooleanArray::from(vec![true, false, true]);

        let state = acc.convert_to_state(&input, Some(&filter)).unwrap();
        let sums = state[0].as_primitive::<Float64Type>();
        let counts = state[1].as_primitive::<Int64Type>();

        // A filtered-out row must be null in both state columns.
        assert!(!sums.is_null(0));
        assert!(sums.is_null(1));
        assert!(!sums.is_null(2));

        assert_eq!(counts.value(0), 1);
        assert!(counts.is_null(1));
        assert_eq!(counts.value(2), 1);
    }

    #[test]
    fn convert_to_state_roundtrips_through_merge() {
        let mut acc = make_acc();
        let input: Vec<ArrayRef> =
            vec![Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0]))];

        // Feed the converted state back through merge_batch, mapping all
        // three rows into a single group.
        let state = acc.convert_to_state(&input, None).unwrap();
        acc.merge_batch(&state, &[0, 0, 0], None, 1).unwrap();

        let result = acc.evaluate(EmitTo::All).unwrap();
        let avg = result.as_primitive::<Float64Type>();
        assert_eq!(avg.value(0), 20.0); // (10 + 20 + 30) / 3
    }

    #[test]
    fn convert_to_state_null_merge_matches_direct() {
        // avg([1.0, NULL, 3.0]) must be 2.0 after a convert_to_state ->
        // merge_batch round-trip. Before the merge-path null fix this
        // leaked the backing buffer value at the null slot and produced
        // the wrong average.
        let mut acc = make_acc();
        let input: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
            Some(1.0),
            None,
            Some(3.0),
        ]))];

        let state = acc.convert_to_state(&input, None).unwrap();
        acc.merge_batch(&state, &[0, 0, 0], None, 1).unwrap();

        let result = acc.evaluate(EmitTo::All).unwrap();
        let avg = result.as_primitive::<Float64Type>();
        assert_eq!(avg.value(0), 2.0);
    }
}

0 commit comments

Comments
 (0)