Skip to content

Commit ee0c865

Browse files
committed
implement_group_accumulators_count_distinct_use_hashtable
1 parent f0b2a4a commit ee0c865

File tree

1 file changed

+72
-50
lines changed
  • datafusion/functions-aggregate-common/src/aggregate/count_distinct

1 file changed

+72
-50
lines changed

datafusion/functions-aggregate-common/src/aggregate/count_distinct/groups.rs

Lines changed: 72 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,71 @@
1616
// under the License.
1717

1818
use arrow::array::{
19-
Array, ArrayRef, AsArray, BooleanArray, Int64Array, ListBuilder, PrimitiveBuilder,
19+
ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, PrimitiveArray,
2020
};
21-
use arrow::datatypes::ArrowPrimitiveType;
21+
use arrow::buffer::OffsetBuffer;
22+
use arrow::datatypes::{ArrowPrimitiveType, Field};
2223
use datafusion_common::HashSet;
2324
use datafusion_common::hash_utils::RandomState;
2425
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
2526
use std::hash::Hash;
2627
use std::mem::size_of;
2728
use std::sync::Arc;
2829

30+
use crate::aggregate::groups_accumulator::accumulate::accumulate;
31+
2932
pub struct PrimitiveDistinctCountGroupsAccumulator<T: ArrowPrimitiveType>
3033
where
3134
T::Native: Eq + Hash,
3235
{
33-
/// Count distinct per group.
34-
values: Vec<HashSet<T::Native, RandomState>>,
36+
seen: HashSet<(usize, T::Native), RandomState>,
37+
num_groups: usize,
3538
}
3639

37-
impl<T: ArrowPrimitiveType> Default for PrimitiveDistinctCountGroupsAccumulator<T>
40+
impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
3841
where
3942
T::Native: Eq + Hash,
4043
{
41-
fn default() -> Self {
42-
Self::new()
44+
pub fn new() -> Self {
45+
Self {
46+
seen: HashSet::default(),
47+
num_groups: 0,
48+
}
49+
}
50+
51+
fn emit_to_values(&mut self, emit_to: EmitTo) -> Vec<Vec<T::Native>> {
52+
let num_emitted = match emit_to {
53+
EmitTo::All => self.num_groups,
54+
EmitTo::First(n) => n,
55+
};
56+
57+
let mut group_values: Vec<Vec<T::Native>> = vec![Vec::new(); num_emitted];
58+
let mut remaining = HashSet::default();
59+
60+
for (group_idx, value) in self.seen.drain() {
61+
if group_idx < num_emitted {
62+
group_values[group_idx].push(value);
63+
} else {
64+
remaining.insert((group_idx - num_emitted, value));
65+
}
66+
}
67+
68+
self.seen = remaining;
69+
match emit_to {
70+
EmitTo::All => self.num_groups = 0,
71+
EmitTo::First(n) => self.num_groups = self.num_groups.saturating_sub(n),
72+
}
73+
74+
group_values
4375
}
4476
}
4577

46-
impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
78+
impl<T: ArrowPrimitiveType> Default for PrimitiveDistinctCountGroupsAccumulator<T>
4779
where
4880
T::Native: Eq + Hash,
4981
{
50-
pub fn new() -> Self {
51-
Self { values: Vec::new() }
82+
fn default() -> Self {
83+
Self::new()
5284
}
5385
}
5486

@@ -64,47 +96,40 @@ where
6496
opt_filter: Option<&BooleanArray>,
6597
total_num_groups: usize,
6698
) -> datafusion_common::Result<()> {
67-
self.values.resize_with(total_num_groups, HashSet::default);
68-
debug_assert_eq!(values.len(), 1, "multiple arguments are not supported");
69-
99+
debug_assert_eq!(values.len(), 1);
100+
self.num_groups = self.num_groups.max(total_num_groups);
70101
let arr = values[0].as_primitive::<T>();
71-
72-
for (idx, group_idx) in group_indices.iter().enumerate() {
73-
if let Some(filter) = opt_filter
74-
&& !filter.value(idx)
75-
{
76-
continue;
77-
}
78-
if arr.is_valid(idx) {
79-
let value = arr.value(idx);
80-
self.values[*group_idx].insert(value);
81-
}
82-
}
83-
102+
accumulate(group_indices, arr, opt_filter, |group_idx, value| {
103+
self.seen.insert((group_idx, value));
104+
});
84105
Ok(())
85106
}
86107

87108
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
88-
let counts: Vec<i64> = emit_to
89-
.take_needed(&mut self.values)
90-
.iter()
91-
.map(|groups| groups.len() as i64)
92-
.collect();
93-
109+
let group_values = self.emit_to_values(emit_to);
110+
let counts: Vec<i64> = group_values.iter().map(|v| v.len() as i64).collect();
94111
Ok(Arc::new(Int64Array::from(counts)))
95112
}
96113

97114
fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
98-
let hash_sets = emit_to.take_needed(&mut self.values);
99-
let mut builder = ListBuilder::new(PrimitiveBuilder::<T>::new());
115+
let group_values = self.emit_to_values(emit_to);
100116

101-
for set in hash_sets {
102-
for value in set {
103-
builder.values().append_value(value);
104-
}
105-
builder.append(true);
117+
let mut offsets = vec![0i32];
118+
let mut all_values = Vec::new();
119+
for values in &group_values {
120+
all_values.extend(values.iter().copied());
121+
offsets.push(all_values.len() as i32);
106122
}
107-
Ok(vec![Arc::new(builder.finish())])
123+
124+
let values_array = Arc::new(PrimitiveArray::<T>::from_iter_values(all_values));
125+
let list_array = ListArray::new(
126+
Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
127+
OffsetBuffer::new(offsets.into()),
128+
values_array,
129+
None,
130+
);
131+
132+
Ok(vec![Arc::new(list_array)])
108133
}
109134

110135
fn merge_batch(
@@ -114,26 +139,23 @@ where
114139
_opt_filter: Option<&BooleanArray>,
115140
total_num_groups: usize,
116141
) -> datafusion_common::Result<()> {
117-
self.values.resize_with(total_num_groups, HashSet::default);
142+
debug_assert_eq!(values.len(), 1);
143+
self.num_groups = self.num_groups.max(total_num_groups);
118144
let list_array = values[0].as_list::<i32>();
119145

120146
for (row_idx, group_idx) in group_indices.iter().enumerate() {
121147
let inner = list_array.value(row_idx);
122-
let inner_set = inner.as_primitive::<T>();
123-
for i in 0..inner.len() {
124-
self.values[*group_idx].insert(inner_set.value(i));
148+
let inner_arr = inner.as_primitive::<T>();
149+
for value in inner_arr.values().iter() {
150+
self.seen.insert((*group_idx, *value));
125151
}
126152
}
153+
127154
Ok(())
128155
}
129156

130157
fn size(&self) -> usize {
131158
size_of::<Self>()
132-
+ self.values.capacity() * size_of::<HashSet<T::Native, RandomState>>()
133-
+ self
134-
.values
135-
.iter()
136-
.map(|s| s.capacity() * size_of::<T::Native>())
137-
.sum::<usize>()
159+
+ self.seen.capacity() * (size_of::<(usize, T::Native)>() + size_of::<u64>())
138160
}
139161
}

0 commit comments

Comments
 (0)