Skip to content

Commit 3a23bb2

Browse files
authored
perf: Optimize array_agg() using GroupsAccumulator (#20504)
## Which issue does this PR close? - Closes #20465. - Closes #17446. ## Rationale for this change This PR optimizes the performance of `array_agg()` by adding support for the `GroupsAccumulator` API. The design tries to minimize the amount of per-batch work done in `update_batch()`: we store a reference to the batch, and a `(group_idx, row_idx)` pair for each row. In `evaluate()`, we assemble all the requested output with a single `interleave` call. This turns out to be significantly faster, because we copy much less data and assembling the results can be vectorized more effectively. For example, on a benchmark with 5000 groups and 5000 int64 values per group, this approach is roughly 190x faster than the previous approach. Releasing memory after a partial emit is a little more involved than the previous approach, but with some determination it is still possible. ## What changes are included in this PR? * Implement the `GroupsAccumulator` API for `array_agg()` * Add benchmark for `array_agg` of a named struct over a dict, following the workload in #17446 * Add unit tests * Improve SLT test coverage * Remove a redundant SLT test ## Are these changes tested? Yes, and benchmarked. ## Are there any user-facing changes? No. ## AI usage Iterated with the help of multiple AI tools; I've reviewed and understand the resulting code.
1 parent 73fbd48 commit 3a23bb2

5 files changed

Lines changed: 773 additions & 17 deletions

File tree

datafusion/core/benches/aggregate_query_sql.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,17 @@ fn criterion_benchmark(c: &mut Criterion) {
284284
)
285285
})
286286
});
287+
288+
c.bench_function("array_agg_struct_query_group_by_mid_groups", |b| {
289+
b.iter(|| {
290+
query(
291+
ctx.clone(),
292+
&rt,
293+
"SELECT u64_mid, array_agg(named_struct('market', dict10, 'price', f64)) \
294+
FROM t GROUP BY u64_mid",
295+
)
296+
})
297+
});
287298
}
288299

289300
criterion_group!(benches, criterion_benchmark);

datafusion/core/benches/data_utils/mod.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
use arrow::array::{
2121
ArrayRef, Float32Array, Float64Array, RecordBatch, StringArray, StringViewBuilder,
2222
UInt64Array,
23-
builder::{Int64Builder, StringBuilder},
23+
builder::{Int64Builder, StringBuilder, StringDictionaryBuilder},
2424
};
25+
use arrow::datatypes::Int32Type;
2526
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
2627
use datafusion::datasource::MemTable;
2728
use datafusion::error::Result;
@@ -65,6 +66,11 @@ pub fn create_schema() -> Schema {
6566
// Integers randomly selected from a narrow range of values such that
6667
// there are a few distinct values, but they are repeated often.
6768
Field::new("u64_narrow", DataType::UInt64, false),
69+
Field::new(
70+
"dict10",
71+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
72+
true,
73+
),
6874
])
6975
}
7076

@@ -109,6 +115,15 @@ fn create_record_batch(
109115
.map(|_| rng.random_range(0..10))
110116
.collect::<Vec<_>>();
111117

118+
let mut dict_builder = StringDictionaryBuilder::<Int32Type>::new();
119+
for _ in 0..batch_size {
120+
if rng.random::<f64>() > 0.9 {
121+
dict_builder.append_null();
122+
} else {
123+
dict_builder.append_value(format!("market_{}", rng.random_range(0..10)));
124+
}
125+
}
126+
112127
RecordBatch::try_new(
113128
schema,
114129
vec![
@@ -118,6 +133,7 @@ fn create_record_batch(
118133
Arc::new(UInt64Array::from(integer_values_wide)),
119134
Arc::new(UInt64Array::from(integer_values_mid)),
120135
Arc::new(UInt64Array::from(integer_values_narrow)),
136+
Arc::new(dict_builder.finish()),
121137
],
122138
)
123139
.unwrap()

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ pub fn set_nulls<T: ArrowNumericType + Send>(
4444
/// The `NullBuffer` is
4545
/// * `true` (representing valid) for values that were `true` in filter
4646
/// * `false` (representing null) for values that were `false` or `null` in filter
47-
fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
47+
pub fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
4848
let (filter_bools, filter_nulls) = filter.clone().into_parts();
4949
let filter_bools = NullBuffer::from(filter_bools);
5050
NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())

0 commit comments

Comments
 (0)