Skip to content

Commit 929e081

Browse files
committed
hashtable_with_count_vector_approach
1 parent f982d8d commit 929e081

2 files changed

Lines changed: 80 additions & 86 deletions

File tree

  • datafusion
    • functions-aggregate-common/src/aggregate/count_distinct
    • functions-aggregate/src

datafusion/functions-aggregate-common/src/aggregate/count_distinct/groups.rs

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ where
3434
T::Native: Eq + Hash,
3535
{
3636
seen: HashSet<(usize, T::Native), RandomState>,
37-
num_groups: usize,
37+
counts: Vec<i64>,
3838
}
3939

4040
impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
@@ -44,7 +44,7 @@ where
4444
pub fn new() -> Self {
4545
Self {
4646
seen: HashSet::default(),
47-
num_groups: 0,
47+
counts: Vec::new(),
4848
}
4949
}
5050
}
@@ -71,47 +71,40 @@ where
7171
total_num_groups: usize,
7272
) -> datafusion_common::Result<()> {
7373
debug_assert_eq!(values.len(), 1);
74-
self.num_groups = self.num_groups.max(total_num_groups);
74+
self.counts.resize(total_num_groups, 0);
7575
let arr = values[0].as_primitive::<T>();
7676
accumulate(group_indices, arr, opt_filter, |group_idx, value| {
77-
self.seen.insert((group_idx, value));
77+
if self.seen.insert((group_idx, value)) {
78+
self.counts[group_idx] += 1;
79+
}
7880
});
7981
Ok(())
8082
}
8183

8284
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
83-
let num_emitted = match emit_to {
84-
EmitTo::All => self.num_groups,
85-
EmitTo::First(n) => n,
86-
};
85+
let counts = emit_to.take_needed(&mut self.counts);
8786

88-
let mut counts = vec![0i64; num_emitted];
89-
90-
if matches!(emit_to, EmitTo::All) {
91-
for &(group_idx, _) in self.seen.iter() {
92-
counts[group_idx] += 1;
87+
match emit_to {
88+
EmitTo::All => {
89+
self.seen.clear();
9390
}
94-
self.seen.clear();
95-
self.num_groups = 0;
96-
} else {
97-
let mut remaining = HashSet::default();
98-
for (group_idx, value) in self.seen.drain() {
99-
if group_idx < num_emitted {
100-
counts[group_idx] += 1;
101-
} else {
102-
remaining.insert((group_idx - num_emitted, value));
91+
EmitTo::First(n) => {
92+
let mut remaining = HashSet::default();
93+
for (group_idx, value) in self.seen.drain() {
94+
if group_idx >= n {
95+
remaining.insert((group_idx - n, value));
96+
}
10397
}
98+
self.seen = remaining;
10499
}
105-
self.seen = remaining;
106-
self.num_groups = self.num_groups.saturating_sub(num_emitted);
107100
}
108101

109102
Ok(Arc::new(Int64Array::from(counts)))
110103
}
111104

112105
fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
113106
let num_emitted = match emit_to {
114-
EmitTo::All => self.num_groups,
107+
EmitTo::All => self.counts.len(),
115108
EmitTo::First(n) => n,
116109
};
117110

@@ -121,7 +114,7 @@ where
121114
for (group_idx, value) in self.seen.drain() {
122115
group_values[group_idx].push(value);
123116
}
124-
self.num_groups = 0;
117+
self.counts.clear();
125118
} else {
126119
let mut remaining = HashSet::default();
127120
for (group_idx, value) in self.seen.drain() {
@@ -132,7 +125,7 @@ where
132125
}
133126
}
134127
self.seen = remaining;
135-
self.num_groups = self.num_groups.saturating_sub(num_emitted);
128+
let _ = emit_to.take_needed(&mut self.counts);
136129
}
137130

138131
let mut offsets = vec![0i32];
@@ -161,14 +154,16 @@ where
161154
total_num_groups: usize,
162155
) -> datafusion_common::Result<()> {
163156
debug_assert_eq!(values.len(), 1);
164-
self.num_groups = self.num_groups.max(total_num_groups);
157+
self.counts.resize(total_num_groups, 0);
165158
let list_array = values[0].as_list::<i32>();
166159

167-
for (row_idx, group_idx) in group_indices.iter().enumerate() {
160+
for (row_idx, &group_idx) in group_indices.iter().enumerate() {
168161
let inner = list_array.value(row_idx);
169162
let inner_arr = inner.as_primitive::<T>();
170-
for value in inner_arr.values().iter() {
171-
self.seen.insert((*group_idx, *value));
163+
for &value in inner_arr.values().iter() {
164+
if self.seen.insert((group_idx, value)) {
165+
self.counts[group_idx] += 1;
166+
}
172167
}
173168
}
174169

@@ -178,5 +173,6 @@ where
178173
fn size(&self) -> usize {
179174
size_of::<Self>()
180175
+ self.seen.capacity() * (size_of::<(usize, T::Native)>() + size_of::<u64>())
176+
+ self.counts.capacity() * size_of::<i64>()
181177
}
182178
}

datafusion/functions-aggregate/src/count.rs

Lines changed: 53 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -340,69 +340,30 @@ impl AggregateUDFImpl for Count {
340340
if args.exprs.len() != 1 {
341341
return false;
342342
}
343-
if args.is_distinct {
344-
// Only support primitive integer types for now
345-
matches!(
346-
args.expr_fields[0].data_type(),
347-
DataType::Int8
348-
| DataType::Int16
349-
| DataType::Int32
350-
| DataType::Int64
351-
| DataType::UInt8
352-
| DataType::UInt16
353-
| DataType::UInt32
354-
| DataType::UInt64
355-
)
356-
} else {
357-
true
343+
if !args.is_distinct {
344+
return true;
358345
}
346+
matches!(
347+
args.expr_fields[0].data_type(),
348+
DataType::Int8
349+
| DataType::Int16
350+
| DataType::Int32
351+
| DataType::Int64
352+
| DataType::UInt8
353+
| DataType::UInt16
354+
| DataType::UInt32
355+
| DataType::UInt64
356+
)
359357
}
360358

361359
fn create_groups_accumulator(
362360
&self,
363361
args: AccumulatorArgs,
364362
) -> Result<Box<dyn GroupsAccumulator>> {
365-
if args.is_distinct {
366-
let data_type = args.expr_fields[0].data_type();
367-
return match data_type {
368-
DataType::Int8 => Ok(Box::new(
369-
PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(),
370-
)),
371-
DataType::Int16 => Ok(Box::new(
372-
PrimitiveDistinctCountGroupsAccumulator::<Int16Type>::new(),
373-
)),
374-
DataType::Int32 => Ok(Box::new(
375-
PrimitiveDistinctCountGroupsAccumulator::<Int32Type>::new(),
376-
)),
377-
DataType::Int64 => Ok(Box::new(
378-
PrimitiveDistinctCountGroupsAccumulator::<Int64Type>::new(),
379-
)),
380-
DataType::UInt8 => Ok(Box::new(
381-
PrimitiveDistinctCountGroupsAccumulator::<UInt8Type>::new(),
382-
)),
383-
DataType::UInt16 => {
384-
Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
385-
UInt16Type,
386-
>::new()))
387-
}
388-
DataType::UInt32 => {
389-
Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
390-
UInt32Type,
391-
>::new()))
392-
}
393-
DataType::UInt64 => {
394-
Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
395-
UInt64Type,
396-
>::new()))
397-
}
398-
_ => not_impl_err!(
399-
"GroupsAccumulator not supported for COUNT(DISTINCT) with {}",
400-
data_type
401-
),
402-
};
363+
if !args.is_distinct {
364+
return Ok(Box::new(CountGroupsAccumulator::new()));
403365
}
404-
// instantiate specialized accumulator
405-
Ok(Box::new(CountGroupsAccumulator::new()))
366+
create_distinct_count_groups_accumulator(args)
406367
}
407368

408369
fn reverse_expr(&self) -> ReversedUDAF {
@@ -475,6 +436,43 @@ impl AggregateUDFImpl for Count {
475436
}
476437
}
477438

439+
#[cold]
440+
fn create_distinct_count_groups_accumulator(
441+
args: AccumulatorArgs,
442+
) -> Result<Box<dyn GroupsAccumulator>> {
443+
let data_type = args.expr_fields[0].data_type();
444+
match data_type {
445+
DataType::Int8 => Ok(Box::new(
446+
PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(),
447+
)),
448+
DataType::Int16 => Ok(Box::new(
449+
PrimitiveDistinctCountGroupsAccumulator::<Int16Type>::new(),
450+
)),
451+
DataType::Int32 => Ok(Box::new(
452+
PrimitiveDistinctCountGroupsAccumulator::<Int32Type>::new(),
453+
)),
454+
DataType::Int64 => Ok(Box::new(
455+
PrimitiveDistinctCountGroupsAccumulator::<Int64Type>::new(),
456+
)),
457+
DataType::UInt8 => Ok(Box::new(
458+
PrimitiveDistinctCountGroupsAccumulator::<UInt8Type>::new(),
459+
)),
460+
DataType::UInt16 => Ok(Box::new(
461+
PrimitiveDistinctCountGroupsAccumulator::<UInt16Type>::new(),
462+
)),
463+
DataType::UInt32 => Ok(Box::new(
464+
PrimitiveDistinctCountGroupsAccumulator::<UInt32Type>::new(),
465+
)),
466+
DataType::UInt64 => Ok(Box::new(
467+
PrimitiveDistinctCountGroupsAccumulator::<UInt64Type>::new(),
468+
)),
469+
_ => not_impl_err!(
470+
"GroupsAccumulator not supported for COUNT(DISTINCT) with {}",
471+
data_type
472+
),
473+
}
474+
}
475+
478476
// DistinctCountAccumulator does not support retract_batch and sliding window
479477
// this is a specialized accumulator for distinct count that supports retract_batch
480478
// and sliding window.

0 commit comments

Comments
 (0)