Skip to content

Commit 8d76cab

Browse files
committed
hashtable_with_count_vector_approach
1 parent 36f232c commit 8d76cab

2 files changed

Lines changed: 80 additions & 86 deletions

File tree

  • datafusion
    • functions-aggregate-common/src/aggregate/count_distinct
    • functions-aggregate/src

datafusion/functions-aggregate-common/src/aggregate/count_distinct/groups.rs

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ where
3434
T::Native: Eq + Hash,
3535
{
3636
seen: HashSet<(usize, T::Native), RandomState>,
37-
num_groups: usize,
37+
counts: Vec<i64>,
3838
}
3939

4040
impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
@@ -44,7 +44,7 @@ where
4444
pub fn new() -> Self {
4545
Self {
4646
seen: HashSet::default(),
47-
num_groups: 0,
47+
counts: Vec::new(),
4848
}
4949
}
5050
}
@@ -71,47 +71,40 @@ where
7171
total_num_groups: usize,
7272
) -> datafusion_common::Result<()> {
7373
debug_assert_eq!(values.len(), 1);
74-
self.num_groups = self.num_groups.max(total_num_groups);
74+
self.counts.resize(total_num_groups, 0);
7575
let arr = values[0].as_primitive::<T>();
7676
accumulate(group_indices, arr, opt_filter, |group_idx, value| {
77-
self.seen.insert((group_idx, value));
77+
if self.seen.insert((group_idx, value)) {
78+
self.counts[group_idx] += 1;
79+
}
7880
});
7981
Ok(())
8082
}
8183

8284
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
83-
let num_emitted = match emit_to {
84-
EmitTo::All => self.num_groups,
85-
EmitTo::First(n) => n,
86-
};
85+
let counts = emit_to.take_needed(&mut self.counts);
8786

88-
let mut counts = vec![0i64; num_emitted];
89-
90-
if matches!(emit_to, EmitTo::All) {
91-
for &(group_idx, _) in self.seen.iter() {
92-
counts[group_idx] += 1;
87+
match emit_to {
88+
EmitTo::All => {
89+
self.seen.clear();
9390
}
94-
self.seen.clear();
95-
self.num_groups = 0;
96-
} else {
97-
let mut remaining = HashSet::default();
98-
for (group_idx, value) in self.seen.drain() {
99-
if group_idx < num_emitted {
100-
counts[group_idx] += 1;
101-
} else {
102-
remaining.insert((group_idx - num_emitted, value));
91+
EmitTo::First(n) => {
92+
let mut remaining = HashSet::default();
93+
for (group_idx, value) in self.seen.drain() {
94+
if group_idx >= n {
95+
remaining.insert((group_idx - n, value));
96+
}
10397
}
98+
self.seen = remaining;
10499
}
105-
self.seen = remaining;
106-
self.num_groups = self.num_groups.saturating_sub(num_emitted);
107100
}
108101

109102
Ok(Arc::new(Int64Array::from(counts)))
110103
}
111104

112105
fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
113106
let num_emitted = match emit_to {
114-
EmitTo::All => self.num_groups,
107+
EmitTo::All => self.counts.len(),
115108
EmitTo::First(n) => n,
116109
};
117110

@@ -121,7 +114,7 @@ where
121114
for (group_idx, value) in self.seen.drain() {
122115
group_values[group_idx].push(value);
123116
}
124-
self.num_groups = 0;
117+
self.counts.clear();
125118
} else {
126119
let mut remaining = HashSet::default();
127120
for (group_idx, value) in self.seen.drain() {
@@ -132,7 +125,7 @@ where
132125
}
133126
}
134127
self.seen = remaining;
135-
self.num_groups = self.num_groups.saturating_sub(num_emitted);
128+
let _ = emit_to.take_needed(&mut self.counts);
136129
}
137130

138131
let mut offsets = vec![0i32];
@@ -161,14 +154,16 @@ where
161154
total_num_groups: usize,
162155
) -> datafusion_common::Result<()> {
163156
debug_assert_eq!(values.len(), 1);
164-
self.num_groups = self.num_groups.max(total_num_groups);
157+
self.counts.resize(total_num_groups, 0);
165158
let list_array = values[0].as_list::<i32>();
166159

167-
for (row_idx, group_idx) in group_indices.iter().enumerate() {
160+
for (row_idx, &group_idx) in group_indices.iter().enumerate() {
168161
let inner = list_array.value(row_idx);
169162
let inner_arr = inner.as_primitive::<T>();
170-
for value in inner_arr.values().iter() {
171-
self.seen.insert((*group_idx, *value));
163+
for &value in inner_arr.values().iter() {
164+
if self.seen.insert((group_idx, value)) {
165+
self.counts[group_idx] += 1;
166+
}
172167
}
173168
}
174169

@@ -178,5 +173,6 @@ where
178173
fn size(&self) -> usize {
179174
size_of::<Self>()
180175
+ self.seen.capacity() * (size_of::<(usize, T::Native)>() + size_of::<u64>())
176+
+ self.counts.capacity() * size_of::<i64>()
181177
}
182178
}

datafusion/functions-aggregate/src/count.rs

Lines changed: 53 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -348,69 +348,30 @@ impl AggregateUDFImpl for Count {
348348
if args.exprs.len() != 1 {
349349
return false;
350350
}
351-
if args.is_distinct {
352-
// Only support primitive integer types for now
353-
matches!(
354-
args.expr_fields[0].data_type(),
355-
DataType::Int8
356-
| DataType::Int16
357-
| DataType::Int32
358-
| DataType::Int64
359-
| DataType::UInt8
360-
| DataType::UInt16
361-
| DataType::UInt32
362-
| DataType::UInt64
363-
)
364-
} else {
365-
true
351+
if !args.is_distinct {
352+
return true;
366353
}
354+
matches!(
355+
args.expr_fields[0].data_type(),
356+
DataType::Int8
357+
| DataType::Int16
358+
| DataType::Int32
359+
| DataType::Int64
360+
| DataType::UInt8
361+
| DataType::UInt16
362+
| DataType::UInt32
363+
| DataType::UInt64
364+
)
367365
}
368366

369367
fn create_groups_accumulator(
370368
&self,
371369
args: AccumulatorArgs,
372370
) -> Result<Box<dyn GroupsAccumulator>> {
373-
if args.is_distinct {
374-
let data_type = args.expr_fields[0].data_type();
375-
return match data_type {
376-
DataType::Int8 => Ok(Box::new(
377-
PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(),
378-
)),
379-
DataType::Int16 => Ok(Box::new(
380-
PrimitiveDistinctCountGroupsAccumulator::<Int16Type>::new(),
381-
)),
382-
DataType::Int32 => Ok(Box::new(
383-
PrimitiveDistinctCountGroupsAccumulator::<Int32Type>::new(),
384-
)),
385-
DataType::Int64 => Ok(Box::new(
386-
PrimitiveDistinctCountGroupsAccumulator::<Int64Type>::new(),
387-
)),
388-
DataType::UInt8 => Ok(Box::new(
389-
PrimitiveDistinctCountGroupsAccumulator::<UInt8Type>::new(),
390-
)),
391-
DataType::UInt16 => {
392-
Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
393-
UInt16Type,
394-
>::new()))
395-
}
396-
DataType::UInt32 => {
397-
Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
398-
UInt32Type,
399-
>::new()))
400-
}
401-
DataType::UInt64 => {
402-
Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
403-
UInt64Type,
404-
>::new()))
405-
}
406-
_ => not_impl_err!(
407-
"GroupsAccumulator not supported for COUNT(DISTINCT) with {}",
408-
data_type
409-
),
410-
};
371+
if !args.is_distinct {
372+
return Ok(Box::new(CountGroupsAccumulator::new()));
411373
}
412-
// instantiate specialized accumulator
413-
Ok(Box::new(CountGroupsAccumulator::new()))
374+
create_distinct_count_groups_accumulator(args)
414375
}
415376

416377
fn reverse_expr(&self) -> ReversedUDAF {
@@ -483,6 +444,43 @@ impl AggregateUDFImpl for Count {
483444
}
484445
}
485446

447+
#[cold]
448+
fn create_distinct_count_groups_accumulator(
449+
args: AccumulatorArgs,
450+
) -> Result<Box<dyn GroupsAccumulator>> {
451+
let data_type = args.expr_fields[0].data_type();
452+
match data_type {
453+
DataType::Int8 => Ok(Box::new(
454+
PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(),
455+
)),
456+
DataType::Int16 => Ok(Box::new(
457+
PrimitiveDistinctCountGroupsAccumulator::<Int16Type>::new(),
458+
)),
459+
DataType::Int32 => Ok(Box::new(
460+
PrimitiveDistinctCountGroupsAccumulator::<Int32Type>::new(),
461+
)),
462+
DataType::Int64 => Ok(Box::new(
463+
PrimitiveDistinctCountGroupsAccumulator::<Int64Type>::new(),
464+
)),
465+
DataType::UInt8 => Ok(Box::new(
466+
PrimitiveDistinctCountGroupsAccumulator::<UInt8Type>::new(),
467+
)),
468+
DataType::UInt16 => Ok(Box::new(
469+
PrimitiveDistinctCountGroupsAccumulator::<UInt16Type>::new(),
470+
)),
471+
DataType::UInt32 => Ok(Box::new(
472+
PrimitiveDistinctCountGroupsAccumulator::<UInt32Type>::new(),
473+
)),
474+
DataType::UInt64 => Ok(Box::new(
475+
PrimitiveDistinctCountGroupsAccumulator::<UInt64Type>::new(),
476+
)),
477+
_ => not_impl_err!(
478+
"GroupsAccumulator not supported for COUNT(DISTINCT) with {}",
479+
data_type
480+
),
481+
}
482+
}
483+
486484
// DistinctCountAccumulator does not support retract_batch and sliding window
487485
// this is a specialized accumulator for distinct count that supports retract_batch
488486
// and sliding window.

0 commit comments

Comments
 (0)