Skip to content

Commit 2cf899e

Browse files
committed
implement_group_accumulators_count_distinct
1 parent a85a513 commit 2cf899e

3 files changed

Lines changed: 125 additions & 36 deletions

File tree

datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
mod bytes;
1919
mod dict;
20-
mod native;
2120
mod groups;
21+
mod native;
2222

2323
pub use bytes::BytesDistinctCountAccumulator;
2424
pub use bytes::BytesViewDistinctCountAccumulator;
2525
pub use dict::DictionaryCountAccumulator;
26+
pub use groups::PrimitiveDistinctCountGroupsAccumulator;
2627
pub use native::FloatDistinctCountAccumulator;
2728
pub use native::PrimitiveDistinctCountAccumulator;
28-
pub use groups::PrimitiveDistinctCountGroupsAccumulator;
Lines changed: 91 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,121 @@
1-
use std::hash::Hash;
2-
use arrow::array::{ArrayRef, BooleanArray, Int64Array};
3-
use arrow::datatypes::{ArrowPrimitiveType, DataType};
1+
use arrow::array::{
2+
Array, ArrayRef, AsArray, BooleanArray, Int64Array, ListBuilder, PrimitiveBuilder,
3+
};
4+
use arrow::datatypes::ArrowPrimitiveType;
5+
use datafusion_common::HashSet;
6+
use datafusion_common::hash_utils::RandomState;
47
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
8+
use std::hash::Hash;
9+
use std::sync::Arc;
510

611
pub struct PrimitiveDistinctCountGroupsAccumulator<T: ArrowPrimitiveType>
7-
where
12+
where
813
T::Native: Eq + Hash,
9-
{
14+
{
1015
/// Count distinct per group.
11-
values: Vec<T>,
12-
data_type : DataType
16+
values: Vec<HashSet<T::Native, RandomState>>,
17+
}
18+
19+
impl<T: ArrowPrimitiveType> Default for PrimitiveDistinctCountGroupsAccumulator<T>
20+
where
21+
T::Native: Eq + Hash,
22+
{
23+
fn default() -> Self {
24+
Self::new()
25+
}
1326
}
1427

1528
impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
1629
where
1730
T::Native: Eq + Hash,
1831
{
19-
pub fn new(data_type: DataType) -> Self {
20-
Self {
21-
values: Vec::new(),
22-
data_type,
23-
}
32+
pub fn new() -> Self {
33+
Self { values: Vec::new() }
2434
}
2535
}
2636

2737
impl<T: ArrowPrimitiveType + Send + std::fmt::Debug> GroupsAccumulator
28-
for PrimitiveDistinctCountGroupsAccumulator<T>
38+
for PrimitiveDistinctCountGroupsAccumulator<T>
2939
where
3040
T::Native: Eq + Hash,
31-
{
32-
fn update_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion_common::Result<()> {
33-
todo!()
41+
{
42+
fn update_batch(
43+
&mut self,
44+
values: &[ArrayRef],
45+
group_indices: &[usize],
46+
opt_filter: Option<&BooleanArray>,
47+
total_num_groups: usize,
48+
) -> datafusion_common::Result<()> {
49+
self.values.resize_with(total_num_groups, HashSet::default);
50+
debug_assert_eq!(values.len(), 1, "multiple arguments are not supported");
51+
52+
let arr = values[0].as_primitive::<T>();
53+
54+
for (idx, group_idx) in group_indices.iter().enumerate() {
55+
if let Some(filter) = opt_filter
56+
&& !filter.value(idx)
57+
{
58+
continue;
59+
}
60+
if arr.is_valid(idx) {
61+
let value = arr.value(idx);
62+
self.values[*group_idx].insert(value);
63+
}
64+
}
65+
66+
Ok(())
3467
}
3568

3669
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
37-
todo!()
70+
let counts: Vec<i64> = emit_to
71+
.take_needed(&mut self.values)
72+
.iter()
73+
.map(|groups| groups.len() as i64)
74+
.collect();
75+
76+
Ok(Arc::new(Int64Array::from(counts)))
3877
}
3978

4079
fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
41-
todo!()
80+
let hash_sets = emit_to.take_needed(&mut self.values);
81+
let mut builder = ListBuilder::new(PrimitiveBuilder::<T>::new());
82+
83+
for set in hash_sets {
84+
for value in set {
85+
builder.values().append_value(value);
86+
}
87+
builder.append(true);
88+
}
89+
Ok(vec![Arc::new(builder.finish())])
4290
}
4391

44-
fn merge_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion_common::Result<()> {
45-
todo!()
92+
fn merge_batch(
93+
&mut self,
94+
values: &[ArrayRef],
95+
group_indices: &[usize],
96+
_opt_filter: Option<&BooleanArray>,
97+
total_num_groups: usize,
98+
) -> datafusion_common::Result<()> {
99+
self.values.resize_with(total_num_groups, HashSet::default);
100+
let list_array = values[0].as_list::<i32>();
101+
102+
for (row_idx, group_idx) in group_indices.iter().enumerate() {
103+
let inner = list_array.value(row_idx);
104+
let inner_set = inner.as_primitive::<T>();
105+
for i in 0..inner.len() {
106+
self.values[*group_idx].insert(inner_set.value(i));
107+
}
108+
}
109+
Ok(())
46110
}
47111

48112
fn size(&self) -> usize {
49-
todo!()
113+
size_of::<Self>()
114+
+ self.values.capacity() * size_of::<HashSet<T::Native, RandomState>>()
115+
+ self
116+
.values
117+
.iter()
118+
.map(|s| s.capacity() * size_of::<T::Native>())
119+
.sum::<usize>()
50120
}
51121
}

datafusion/functions-aggregate/src/count.rs

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use datafusion_expr::{
4141
function::{AccumulatorArgs, StateFieldsArgs},
4242
utils::format_state_name,
4343
};
44+
use datafusion_functions_aggregate_common::aggregate::count_distinct::PrimitiveDistinctCountGroupsAccumulator;
4445
use datafusion_functions_aggregate_common::aggregate::{
4546
count_distinct::BytesDistinctCountAccumulator,
4647
count_distinct::BytesViewDistinctCountAccumulator,
@@ -59,7 +60,6 @@ use std::{
5960
ops::BitAnd,
6061
sync::Arc,
6162
};
62-
use datafusion_functions_aggregate_common::aggregate::count_distinct::PrimitiveDistinctCountGroupsAccumulator;
6363

6464
make_udaf_expr_and_func!(
6565
Count,
@@ -344,19 +344,38 @@ impl AggregateUDFImpl for Count {
344344
&self,
345345
args: AccumulatorArgs,
346346
) -> Result<Box<dyn GroupsAccumulator>> {
347-
if(args.is_distinct){
347+
if args.is_distinct {
348348
let data_type = args.expr_fields[0].data_type();
349-
return match data_type{
350-
DataType::Int8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(data_type.clone()))),
351-
DataType::Int16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<Int16Type>::new(data_type.clone()))),
352-
DataType::Int32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<Int32Type>::new(data_type.clone()))),
353-
DataType::Int64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<Int64Type>::new(data_type.clone()))),
354-
DataType::UInt8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt8Type>::new(data_type.clone()))),
355-
DataType::UInt16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt16Type>::new(data_type.clone()))),
356-
DataType::UInt32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt32Type>::new(data_type.clone()))),
357-
DataType::UInt64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt64Type>::new(data_type.clone()))),
358-
_ => not_impl_err!("GroupsAccumulator not supported for COUNT(DISTINCT) with {}", data_type),
359-
}
349+
return match data_type {
350+
DataType::Int8 => Ok(Box::new(
351+
PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(),
352+
)),
353+
DataType::Int16 => Ok(Box::new(
354+
PrimitiveDistinctCountGroupsAccumulator::<Int16Type>::new(),
355+
)),
356+
DataType::Int32 => Ok(Box::new(
357+
PrimitiveDistinctCountGroupsAccumulator::<Int32Type>::new(),
358+
)),
359+
DataType::Int64 => Ok(Box::new(
360+
PrimitiveDistinctCountGroupsAccumulator::<Int64Type>::new(),
361+
)),
362+
DataType::UInt8 => Ok(Box::new(
363+
PrimitiveDistinctCountGroupsAccumulator::<UInt8Type>::new(),
364+
)),
365+
DataType::UInt16 => {
366+
Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt16Type>::new()))
367+
}
368+
DataType::UInt32 => {
369+
Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt32Type>::new()))
370+
}
371+
DataType::UInt64 => {
372+
Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt64Type>::new()))
373+
}
374+
_ => not_impl_err!(
375+
"GroupsAccumulator not supported for COUNT(DISTINCT) with {}",
376+
data_type
377+
),
378+
};
360379
}
361380
// instantiate specialized accumulator
362381
Ok(Box::new(CountGroupsAccumulator::new()))

0 commit comments

Comments
 (0)