Skip to content

Commit bab467e

Browse files
committed
address comment
1 parent 779706c commit bab467e

3 files changed

Lines changed: 226 additions & 118 deletions

File tree

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 138 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use std::mem::size_of_val;
2424
use std::sync::Arc;
2525

2626
use arrow::array::{
27-
Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder,
27+
Array, ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder,
2828
};
2929
use arrow::buffer::BooleanBuffer;
3030
use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
@@ -79,31 +79,10 @@ pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
7979
.unwrap()
8080
}
8181

82-
fn create_groups_primitive_accumulator<T: ArrowPrimitiveType + Send>(
83-
args: &AccumulatorArgs,
84-
is_first: bool,
85-
) -> Result<Box<dyn GroupsAccumulator>> {
86-
let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
87-
return internal_err!("Groups accumulator must have an ordering.");
88-
};
89-
90-
let ordering_dtypes = ordering
91-
.iter()
92-
.map(|e| e.expr.data_type(args.schema))
93-
.collect::<Result<Vec<_>>>()?;
94-
95-
Ok(Box::new(FirstLastGroupsAccumulator::try_new(
96-
PrimitiveValueState::<T>::new(args.return_field.data_type().clone()),
97-
ordering,
98-
args.ignore_nulls,
99-
&ordering_dtypes,
100-
is_first,
101-
)?))
102-
}
103-
104-
fn create_groups_bytes_accumulator(
82+
fn create_groups_accumulator_helper<S: ValueState + 'static>(
10583
args: &AccumulatorArgs,
10684
is_first: bool,
85+
state: S,
10786
) -> Result<Box<dyn GroupsAccumulator>> {
10887
let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
10988
return internal_err!("Groups accumulator must have an ordering.");
@@ -115,7 +94,7 @@ fn create_groups_bytes_accumulator(
11594
.collect::<Result<Vec<_>>>()?;
11695

11796
Ok(Box::new(FirstLastGroupsAccumulator::try_new(
118-
BytesValueState::try_new(args.return_field.data_type().clone())?,
97+
state,
11998
ordering,
12099
args.ignore_nulls,
121100
&ordering_dtypes,
@@ -128,99 +107,152 @@ fn create_groups_accumulator(
128107
is_first: bool,
129108
function_name: &str,
130109
) -> Result<Box<dyn GroupsAccumulator>> {
131-
match args.return_field.data_type() {
132-
DataType::Int8 => create_groups_primitive_accumulator::<Int8Type>(args, is_first),
133-
DataType::Int16 => {
134-
create_groups_primitive_accumulator::<Int16Type>(args, is_first)
135-
}
136-
DataType::Int32 => {
137-
create_groups_primitive_accumulator::<Int32Type>(args, is_first)
138-
}
139-
DataType::Int64 => {
140-
create_groups_primitive_accumulator::<Int64Type>(args, is_first)
141-
}
142-
DataType::UInt8 => {
143-
create_groups_primitive_accumulator::<UInt8Type>(args, is_first)
144-
}
145-
DataType::UInt16 => {
146-
create_groups_primitive_accumulator::<UInt16Type>(args, is_first)
147-
}
148-
DataType::UInt32 => {
149-
create_groups_primitive_accumulator::<UInt32Type>(args, is_first)
150-
}
151-
DataType::UInt64 => {
152-
create_groups_primitive_accumulator::<UInt64Type>(args, is_first)
153-
}
154-
DataType::Float16 => {
155-
create_groups_primitive_accumulator::<Float16Type>(args, is_first)
156-
}
157-
DataType::Float32 => {
158-
create_groups_primitive_accumulator::<Float32Type>(args, is_first)
159-
}
160-
DataType::Float64 => {
161-
create_groups_primitive_accumulator::<Float64Type>(args, is_first)
162-
}
110+
let data_type = args.return_field.data_type();
111+
match data_type {
112+
DataType::Int8 => create_groups_accumulator_helper(
113+
args,
114+
is_first,
115+
PrimitiveValueState::<Int8Type>::new(data_type.clone()),
116+
),
117+
DataType::Int16 => create_groups_accumulator_helper(
118+
args,
119+
is_first,
120+
PrimitiveValueState::<Int16Type>::new(data_type.clone()),
121+
),
122+
DataType::Int32 => create_groups_accumulator_helper(
123+
args,
124+
is_first,
125+
PrimitiveValueState::<Int32Type>::new(data_type.clone()),
126+
),
127+
DataType::Int64 => create_groups_accumulator_helper(
128+
args,
129+
is_first,
130+
PrimitiveValueState::<Int64Type>::new(data_type.clone()),
131+
),
132+
DataType::UInt8 => create_groups_accumulator_helper(
133+
args,
134+
is_first,
135+
PrimitiveValueState::<UInt8Type>::new(data_type.clone()),
136+
),
137+
DataType::UInt16 => create_groups_accumulator_helper(
138+
args,
139+
is_first,
140+
PrimitiveValueState::<UInt16Type>::new(data_type.clone()),
141+
),
142+
DataType::UInt32 => create_groups_accumulator_helper(
143+
args,
144+
is_first,
145+
PrimitiveValueState::<UInt32Type>::new(data_type.clone()),
146+
),
147+
DataType::UInt64 => create_groups_accumulator_helper(
148+
args,
149+
is_first,
150+
PrimitiveValueState::<UInt64Type>::new(data_type.clone()),
151+
),
152+
DataType::Float16 => create_groups_accumulator_helper(
153+
args,
154+
is_first,
155+
PrimitiveValueState::<Float16Type>::new(data_type.clone()),
156+
),
157+
DataType::Float32 => create_groups_accumulator_helper(
158+
args,
159+
is_first,
160+
PrimitiveValueState::<Float32Type>::new(data_type.clone()),
161+
),
162+
DataType::Float64 => create_groups_accumulator_helper(
163+
args,
164+
is_first,
165+
PrimitiveValueState::<Float64Type>::new(data_type.clone()),
166+
),
163167

164-
DataType::Decimal32(_, _) => {
165-
create_groups_primitive_accumulator::<Decimal32Type>(args, is_first)
166-
}
167-
DataType::Decimal64(_, _) => {
168-
create_groups_primitive_accumulator::<Decimal64Type>(args, is_first)
169-
}
170-
DataType::Decimal128(_, _) => {
171-
create_groups_primitive_accumulator::<Decimal128Type>(args, is_first)
172-
}
173-
DataType::Decimal256(_, _) => {
174-
create_groups_primitive_accumulator::<Decimal256Type>(args, is_first)
175-
}
168+
DataType::Decimal32(_, _) => create_groups_accumulator_helper(
169+
args,
170+
is_first,
171+
PrimitiveValueState::<Decimal32Type>::new(data_type.clone()),
172+
),
173+
DataType::Decimal64(_, _) => create_groups_accumulator_helper(
174+
args,
175+
is_first,
176+
PrimitiveValueState::<Decimal64Type>::new(data_type.clone()),
177+
),
178+
DataType::Decimal128(_, _) => create_groups_accumulator_helper(
179+
args,
180+
is_first,
181+
PrimitiveValueState::<Decimal128Type>::new(data_type.clone()),
182+
),
183+
DataType::Decimal256(_, _) => create_groups_accumulator_helper(
184+
args,
185+
is_first,
186+
PrimitiveValueState::<Decimal256Type>::new(data_type.clone()),
187+
),
176188

177-
DataType::Timestamp(TimeUnit::Second, _) => {
178-
create_groups_primitive_accumulator::<TimestampSecondType>(args, is_first)
179-
}
180-
DataType::Timestamp(TimeUnit::Millisecond, _) => {
181-
create_groups_primitive_accumulator::<TimestampMillisecondType>(
182-
args, is_first,
183-
)
184-
}
185-
DataType::Timestamp(TimeUnit::Microsecond, _) => {
186-
create_groups_primitive_accumulator::<TimestampMicrosecondType>(
187-
args, is_first,
188-
)
189-
}
190-
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
191-
create_groups_primitive_accumulator::<TimestampNanosecondType>(args, is_first)
192-
}
189+
DataType::Timestamp(TimeUnit::Second, _) => create_groups_accumulator_helper(
190+
args,
191+
is_first,
192+
PrimitiveValueState::<TimestampSecondType>::new(data_type.clone()),
193+
),
194+
DataType::Timestamp(TimeUnit::Millisecond, _) => create_groups_accumulator_helper(
195+
args,
196+
is_first,
197+
PrimitiveValueState::<TimestampMillisecondType>::new(data_type.clone()),
198+
),
199+
DataType::Timestamp(TimeUnit::Microsecond, _) => create_groups_accumulator_helper(
200+
args,
201+
is_first,
202+
PrimitiveValueState::<TimestampMicrosecondType>::new(data_type.clone()),
203+
),
204+
DataType::Timestamp(TimeUnit::Nanosecond, _) => create_groups_accumulator_helper(
205+
args,
206+
is_first,
207+
PrimitiveValueState::<TimestampNanosecondType>::new(data_type.clone()),
208+
),
193209

194-
DataType::Date32 => {
195-
create_groups_primitive_accumulator::<Date32Type>(args, is_first)
196-
}
197-
DataType::Date64 => {
198-
create_groups_primitive_accumulator::<Date64Type>(args, is_first)
199-
}
200-
DataType::Time32(TimeUnit::Second) => {
201-
create_groups_primitive_accumulator::<Time32SecondType>(args, is_first)
202-
}
203-
DataType::Time32(TimeUnit::Millisecond) => {
204-
create_groups_primitive_accumulator::<Time32MillisecondType>(args, is_first)
205-
}
206-
DataType::Time64(TimeUnit::Microsecond) => {
207-
create_groups_primitive_accumulator::<Time64MicrosecondType>(args, is_first)
208-
}
209-
DataType::Time64(TimeUnit::Nanosecond) => {
210-
create_groups_primitive_accumulator::<Time64NanosecondType>(args, is_first)
211-
}
210+
DataType::Date32 => create_groups_accumulator_helper(
211+
args,
212+
is_first,
213+
PrimitiveValueState::<Date32Type>::new(data_type.clone()),
214+
),
215+
DataType::Date64 => create_groups_accumulator_helper(
216+
args,
217+
is_first,
218+
PrimitiveValueState::<Date64Type>::new(data_type.clone()),
219+
),
220+
DataType::Time32(TimeUnit::Second) => create_groups_accumulator_helper(
221+
args,
222+
is_first,
223+
PrimitiveValueState::<Time32SecondType>::new(data_type.clone()),
224+
),
225+
DataType::Time32(TimeUnit::Millisecond) => create_groups_accumulator_helper(
226+
args,
227+
is_first,
228+
PrimitiveValueState::<Time32MillisecondType>::new(data_type.clone()),
229+
),
230+
DataType::Time64(TimeUnit::Microsecond) => create_groups_accumulator_helper(
231+
args,
232+
is_first,
233+
PrimitiveValueState::<Time64MicrosecondType>::new(data_type.clone()),
234+
),
235+
DataType::Time64(TimeUnit::Nanosecond) => create_groups_accumulator_helper(
236+
args,
237+
is_first,
238+
PrimitiveValueState::<Time64NanosecondType>::new(data_type.clone()),
239+
),
212240

213241
DataType::Utf8
214242
| DataType::LargeUtf8
215243
| DataType::Utf8View
216244
| DataType::Binary
217245
| DataType::LargeBinary
218-
| DataType::BinaryView => create_groups_bytes_accumulator(args, is_first),
246+
| DataType::BinaryView => create_groups_accumulator_helper(
247+
args,
248+
is_first,
249+
BytesValueState::try_new(data_type.clone())?,
250+
),
219251

220252
_ => internal_err!(
221253
"GroupsAccumulator not supported for {}({})",
222254
function_name,
223-
args.return_field.data_type()
255+
data_type
224256
),
225257
}
226258
}
@@ -419,7 +451,7 @@ struct FirstLastGroupsAccumulator<S: ValueState> {
419451
// to avoid calling `ScalarValue::size_of_vec` by Self.size.
420452
size_of_orderings: usize,
421453

422-
// buffer for `get_filtered_min_of_each_group`
454+
// buffer for `get_filtered_extreme_of_each_group`
423455
// extreme_of_each_group_buf.0[group_idx] -> idx_in_val
424456
// only valid if extreme_of_each_group_buf.1[group_idx] == true
425457
extreme_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

datafusion/functions-aggregate/src/first_last/state.rs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ pub(crate) trait ValueState: Send + Sync {
3636
/// Note: While this is not a batch interface, it is not a performance bottleneck.
3737
/// In heavy aggregation benchmarks, the overhead of this method is typically less than 1%.
3838
///
39-
/// Benchmarked queries with < 1% `update` overhead:
4039
/// ```sql
4140
/// -- TPC-H SF10
4241
/// select l_shipmode, last_value(l_partkey order by l_orderkey, l_linenumber, l_comment, l_suppkey, l_tax)
@@ -151,6 +150,10 @@ impl ValueState for BytesValueState {
151150
}
152151

153152
fn update(&mut self, group_idx: usize, array: &ArrayRef, idx: usize) -> Result<()> {
153+
if let Some(v) = &self.vals[group_idx] {
154+
self.total_capacity -= v.capacity();
155+
}
156+
154157
if array.is_null(idx) {
155158
self.vals[group_idx] = None;
156159
} else {
@@ -170,16 +173,17 @@ impl ValueState for BytesValueState {
170173
};
171174

172175
if let Some(v) = &mut self.vals[group_idx] {
173-
self.total_capacity -= v.capacity();
174176
v.clear();
175177
v.extend_from_slice(val);
176178

177-
self.total_capacity += v.capacity();
178179
} else {
179180
let v = val.to_vec();
180-
self.total_capacity += v.capacity();
181181
self.vals[group_idx] = Some(v);
182182
}
183+
184+
self.vals[group_idx]
185+
.as_ref()
186+
.inspect(|x| self.total_capacity += x.capacity());
183187
}
184188
Ok(())
185189
}
@@ -436,4 +440,28 @@ mod tests {
436440

437441
Ok(())
438442
}
443+
444+
#[test]
445+
fn test_bytes_value_state_update_null() -> Result<()> {
446+
let mut state = BytesValueState::try_new(DataType::Utf8)?;
447+
state.resize(1);
448+
449+
let array: ArrayRef = Arc::new(StringArray::from(vec![Some("hello"), None]));
450+
451+
// group 0 = "hello"
452+
state.update(0, &array, 0)?;
453+
assert_eq!(state.total_capacity, state.total_capacity_calculated());
454+
assert!(state.total_capacity > 0);
455+
456+
// group 0 = NULL
457+
state.update(0, &array, 1)?;
458+
assert_eq!(
459+
state.total_capacity,
460+
state.total_capacity_calculated(),
461+
"total_capacity should match calculated capacity after update(NULL)"
462+
);
463+
assert_eq!(state.total_capacity, 0);
464+
465+
Ok(())
466+
}
439467
}

0 commit comments

Comments
 (0)