Skip to content

Commit a85a513

Browse files
committed
init
1 parent fbdf770 commit a85a513

3 files changed

Lines changed: 69 additions & 6 deletions

File tree

datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
mod bytes;
1919
mod dict;
2020
mod native;
21+
mod groups;
2122

2223
pub use bytes::BytesDistinctCountAccumulator;
2324
pub use bytes::BytesViewDistinctCountAccumulator;
2425
pub use dict::DictionaryCountAccumulator;
2526
pub use native::FloatDistinctCountAccumulator;
2627
pub use native::PrimitiveDistinctCountAccumulator;
28+
pub use groups::PrimitiveDistinctCountGroupsAccumulator;
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use std::hash::Hash;
2+
use arrow::array::{ArrayRef, BooleanArray, Int64Array};
3+
use arrow::datatypes::{ArrowPrimitiveType, DataType};
4+
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
5+
6+
pub struct PrimitiveDistinctCountGroupsAccumulator<T: ArrowPrimitiveType>
7+
where
8+
T::Native: Eq + Hash,
9+
{
10+
/// Count distinct per group.
11+
values: Vec<T>,
12+
data_type : DataType
13+
}
14+
15+
impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
16+
where
17+
T::Native: Eq + Hash,
18+
{
19+
pub fn new(data_type: DataType) -> Self {
20+
Self {
21+
values: Vec::new(),
22+
data_type,
23+
}
24+
}
25+
}
26+
27+
impl<T: ArrowPrimitiveType + Send + std::fmt::Debug> GroupsAccumulator
28+
for PrimitiveDistinctCountGroupsAccumulator<T>
29+
where
30+
T::Native: Eq + Hash,
31+
{
32+
fn update_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion_common::Result<()> {
33+
todo!()
34+
}
35+
36+
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
37+
todo!()
38+
}
39+
40+
fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
41+
todo!()
42+
}
43+
44+
fn merge_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion_common::Result<()> {
45+
todo!()
46+
}
47+
48+
fn size(&self) -> usize {
49+
todo!()
50+
}
51+
}

datafusion/functions-aggregate/src/count.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ use std::{
5959
ops::BitAnd,
6060
sync::Arc,
6161
};
62+
use datafusion_functions_aggregate_common::aggregate::count_distinct::PrimitiveDistinctCountGroupsAccumulator;
6263

6364
make_udaf_expr_and_func!(
6465
Count,
@@ -336,18 +337,27 @@ impl AggregateUDFImpl for Count {
336337
}
337338

338339
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
339-
// groups accumulator only supports `COUNT(c1)`, not
340-
// `COUNT(c1, c2)`, etc
341-
if args.is_distinct {
342-
return false;
343-
}
344340
args.exprs.len() == 1
345341
}
346342

347343
fn create_groups_accumulator(
348344
&self,
349-
_args: AccumulatorArgs,
345+
args: AccumulatorArgs,
350346
) -> Result<Box<dyn GroupsAccumulator>> {
347+
if(args.is_distinct){
348+
let data_type = args.expr_fields[0].data_type();
349+
return match data_type{
350+
DataType::Int8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(data_type.clone()))),
351+
DataType::Int16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<Int16Type>::new(data_type.clone()))),
352+
DataType::Int32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<Int32Type>::new(data_type.clone()))),
353+
DataType::Int64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<Int64Type>::new(data_type.clone()))),
354+
DataType::UInt8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt8Type>::new(data_type.clone()))),
355+
DataType::UInt16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt16Type>::new(data_type.clone()))),
356+
DataType::UInt32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt32Type>::new(data_type.clone()))),
357+
DataType::UInt64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<UInt64Type>::new(data_type.clone()))),
358+
_ => not_impl_err!("GroupsAccumulator not supported for COUNT(DISTINCT) with {}", data_type),
359+
}
360+
}
351361
// instantiate specialized accumulator
352362
Ok(Box::new(CountGroupsAccumulator::new()))
353363
}

0 commit comments

Comments
 (0)