Skip to content

Commit 5f8b131

Browse files
authored
perf: Implement groups accumulator count distinct primitive types (#21561)
## Which issue does this PR close? Evaluate perf with group accumulators for count distinct <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #. Related : #21575 benchmark PR ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent cfafce4 commit 5f8b131

3 files changed

Lines changed: 254 additions & 10 deletions

File tree

datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
mod bytes;
1919
mod dict;
20+
mod groups;
2021
mod native;
2122

2223
pub use bytes::BytesDistinctCountAccumulator;
2324
pub use bytes::BytesViewDistinctCountAccumulator;
2425
pub use dict::DictionaryCountAccumulator;
26+
pub use groups::PrimitiveDistinctCountGroupsAccumulator;
2527
pub use native::Bitmap65536DistinctCountAccumulator;
2628
pub use native::Bitmap65536DistinctCountAccumulatorI16;
2729
pub use native::BoolArray256DistinctCountAccumulator;
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{
19+
ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, PrimitiveArray,
20+
};
21+
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
22+
use arrow::datatypes::{ArrowPrimitiveType, Field};
23+
use datafusion_common::HashSet;
24+
use datafusion_common::hash_utils::RandomState;
25+
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
26+
use std::hash::Hash;
27+
use std::mem::size_of;
28+
use std::sync::Arc;
29+
30+
use crate::aggregate::groups_accumulator::accumulate::accumulate;
31+
32+
pub struct PrimitiveDistinctCountGroupsAccumulator<T: ArrowPrimitiveType>
33+
where
34+
T::Native: Eq + Hash,
35+
{
36+
seen: HashSet<(usize, T::Native), RandomState>,
37+
counts: Vec<i64>,
38+
}
39+
40+
impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
41+
where
42+
T::Native: Eq + Hash,
43+
{
44+
pub fn new() -> Self {
45+
Self {
46+
seen: HashSet::default(),
47+
counts: Vec::new(),
48+
}
49+
}
50+
}
51+
52+
impl<T: ArrowPrimitiveType> Default for PrimitiveDistinctCountGroupsAccumulator<T>
53+
where
54+
T::Native: Eq + Hash,
55+
{
56+
fn default() -> Self {
57+
Self::new()
58+
}
59+
}
60+
61+
impl<T: ArrowPrimitiveType + Send + std::fmt::Debug> GroupsAccumulator
62+
for PrimitiveDistinctCountGroupsAccumulator<T>
63+
where
64+
T::Native: Eq + Hash,
65+
{
66+
fn update_batch(
67+
&mut self,
68+
values: &[ArrayRef],
69+
group_indices: &[usize],
70+
opt_filter: Option<&BooleanArray>,
71+
total_num_groups: usize,
72+
) -> datafusion_common::Result<()> {
73+
debug_assert_eq!(values.len(), 1);
74+
self.counts.resize(total_num_groups, 0);
75+
let arr = values[0].as_primitive::<T>();
76+
accumulate(group_indices, arr, opt_filter, |group_idx, value| {
77+
if self.seen.insert((group_idx, value)) {
78+
self.counts[group_idx] += 1;
79+
}
80+
});
81+
Ok(())
82+
}
83+
84+
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
85+
let counts = emit_to.take_needed(&mut self.counts);
86+
87+
match emit_to {
88+
EmitTo::All => {
89+
self.seen.clear();
90+
}
91+
EmitTo::First(n) => {
92+
let mut remaining = HashSet::default();
93+
for (group_idx, value) in self.seen.drain() {
94+
if group_idx >= n {
95+
remaining.insert((group_idx - n, value));
96+
}
97+
}
98+
self.seen = remaining;
99+
}
100+
}
101+
102+
Ok(Arc::new(Int64Array::from(counts)))
103+
}
104+
105+
fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
106+
let num_emitted = match emit_to {
107+
EmitTo::All => self.counts.len(),
108+
EmitTo::First(n) => n,
109+
};
110+
111+
// Prefix-sum counts[..num_emitted] into offsets
112+
let mut offsets = Vec::with_capacity(num_emitted + 1);
113+
offsets.push(0i32);
114+
let mut total = 0i32;
115+
for &c in &self.counts[..num_emitted] {
116+
total += c as i32;
117+
offsets.push(total);
118+
}
119+
120+
let mut all_values = vec![T::Native::default(); total as usize];
121+
let mut cursors: Vec<i32> = offsets[..num_emitted].to_vec();
122+
123+
if matches!(emit_to, EmitTo::All) {
124+
for (group_idx, value) in self.seen.drain() {
125+
let pos = cursors[group_idx] as usize;
126+
all_values[pos] = value;
127+
cursors[group_idx] += 1;
128+
}
129+
self.counts.clear();
130+
} else {
131+
let mut remaining = HashSet::default();
132+
for (group_idx, value) in self.seen.drain() {
133+
if group_idx < num_emitted {
134+
let pos = cursors[group_idx] as usize;
135+
all_values[pos] = value;
136+
cursors[group_idx] += 1;
137+
} else {
138+
remaining.insert((group_idx - num_emitted, value));
139+
}
140+
}
141+
self.seen = remaining;
142+
let _ = emit_to.take_needed(&mut self.counts);
143+
}
144+
145+
let values_array = Arc::new(PrimitiveArray::<T>::new(
146+
ScalarBuffer::from(all_values),
147+
None,
148+
));
149+
let list_array = ListArray::new(
150+
Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
151+
OffsetBuffer::new(offsets.into()),
152+
values_array,
153+
None,
154+
);
155+
156+
Ok(vec![Arc::new(list_array)])
157+
}
158+
159+
fn merge_batch(
160+
&mut self,
161+
values: &[ArrayRef],
162+
group_indices: &[usize],
163+
_opt_filter: Option<&BooleanArray>,
164+
total_num_groups: usize,
165+
) -> datafusion_common::Result<()> {
166+
debug_assert_eq!(values.len(), 1);
167+
self.counts.resize(total_num_groups, 0);
168+
let list_array = values[0].as_list::<i32>();
169+
let inner = list_array.values().as_primitive::<T>();
170+
let inner_values = inner.values();
171+
let offsets = list_array.offsets();
172+
173+
for (row_idx, &group_idx) in group_indices.iter().enumerate() {
174+
let start = offsets[row_idx] as usize;
175+
let end = offsets[row_idx + 1] as usize;
176+
for &value in &inner_values[start..end] {
177+
if self.seen.insert((group_idx, value)) {
178+
self.counts[group_idx] += 1;
179+
}
180+
}
181+
}
182+
183+
Ok(())
184+
}
185+
186+
fn size(&self) -> usize {
187+
size_of::<Self>()
188+
+ self.seen.capacity() * (size_of::<(usize, T::Native)>() + size_of::<u64>())
189+
+ self.counts.capacity() * size_of::<i64>()
190+
}
191+
}

datafusion/functions-aggregate/src/count.rs

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ use arrow::{
2121
compute,
2222
datatypes::{
2323
DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
24-
FieldRef, Float16Type, Float32Type, Float64Type, Int32Type, Int64Type,
25-
Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
24+
FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
25+
Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
2626
Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
2727
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
28-
UInt32Type, UInt64Type,
28+
UInt8Type, UInt16Type, UInt32Type, UInt64Type,
2929
},
3030
};
3131
use datafusion_common::hash_utils::RandomState;
@@ -41,6 +41,7 @@ use datafusion_expr::{
4141
function::{AccumulatorArgs, StateFieldsArgs},
4242
utils::format_state_name,
4343
};
44+
use datafusion_functions_aggregate_common::aggregate::count_distinct::PrimitiveDistinctCountGroupsAccumulator;
4445
use datafusion_functions_aggregate_common::aggregate::{
4546
count_distinct::Bitmap65536DistinctCountAccumulator,
4647
count_distinct::Bitmap65536DistinctCountAccumulatorI16,
@@ -344,20 +345,33 @@ impl AggregateUDFImpl for Count {
344345
}
345346

346347
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
347-
// groups accumulator only supports `COUNT(c1)`, not
348-
// `COUNT(c1, c2)`, etc
349-
if args.is_distinct {
348+
if args.exprs.len() != 1 {
350349
return false;
351350
}
352-
args.exprs.len() == 1
351+
if !args.is_distinct {
352+
return true;
353+
}
354+
matches!(
355+
args.expr_fields[0].data_type(),
356+
DataType::Int8
357+
| DataType::Int16
358+
| DataType::Int32
359+
| DataType::Int64
360+
| DataType::UInt8
361+
| DataType::UInt16
362+
| DataType::UInt32
363+
| DataType::UInt64
364+
)
353365
}
354366

355367
fn create_groups_accumulator(
356368
&self,
357-
_args: AccumulatorArgs,
369+
args: AccumulatorArgs,
358370
) -> Result<Box<dyn GroupsAccumulator>> {
359-
// instantiate specialized accumulator
360-
Ok(Box::new(CountGroupsAccumulator::new()))
371+
if !args.is_distinct {
372+
return Ok(Box::new(CountGroupsAccumulator::new()));
373+
}
374+
create_distinct_count_groups_accumulator(&args)
361375
}
362376

363377
fn reverse_expr(&self) -> ReversedUDAF {
@@ -430,6 +444,43 @@ impl AggregateUDFImpl for Count {
430444
}
431445
}
432446

447+
#[cold]
448+
fn create_distinct_count_groups_accumulator(
449+
args: &AccumulatorArgs,
450+
) -> Result<Box<dyn GroupsAccumulator>> {
451+
let data_type = args.expr_fields[0].data_type();
452+
match data_type {
453+
DataType::Int8 => Ok(Box::new(
454+
PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(),
455+
)),
456+
DataType::Int16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
457+
Int16Type,
458+
>::new())),
459+
DataType::Int32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
460+
Int32Type,
461+
>::new())),
462+
DataType::Int64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
463+
Int64Type,
464+
>::new())),
465+
DataType::UInt8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
466+
UInt8Type,
467+
>::new())),
468+
DataType::UInt16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
469+
UInt16Type,
470+
>::new())),
471+
DataType::UInt32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
472+
UInt32Type,
473+
>::new())),
474+
DataType::UInt64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
475+
UInt64Type,
476+
>::new())),
477+
_ => not_impl_err!(
478+
"GroupsAccumulator not supported for COUNT(DISTINCT) with {}",
479+
data_type
480+
),
481+
}
482+
}
483+
433484
// DistinctCountAccumulator does not support retract_batch and sliding window
434485
// this is a specialized accumulator for distinct count that supports retract_batch
435486
// and sliding window.

0 commit comments

Comments
 (0)