Skip to content

Commit e5012b4

Browse files
committed
saturating sub for unsigned
1 parent f9cfb0d commit e5012b4

1 file changed

Lines changed: 105 additions & 26 deletions

File tree

datafusion/expr/src/window_state.rs

Lines changed: 105 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ use std::{collections::VecDeque, ops::Range, sync::Arc};
2222
use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
2323

2424
use arrow::{
25-
array::{ArrayRef, make_comparator},
25+
array::{ArrayRef, AsArray, PrimitiveArray, make_comparator},
2626
compute::{
2727
SortOptions, concat, concat_batches,
2828
kernels::numeric::{add_wrapping, sub_wrapping},
2929
},
30-
datatypes::{DataType, SchemaRef},
30+
datatypes::{DataType, SchemaRef, UInt8Type, UInt16Type, UInt32Type, UInt64Type},
3131
record_batch::RecordBatch,
3232
};
3333
use datafusion_common::{
@@ -298,7 +298,6 @@ impl PartitionBatchState {
298298
}
299299
}
300300

301-
302301
type SharedDynComparator = Arc<dyn Fn(usize, usize) -> std::cmp::Ordering + Send + Sync>;
303302

304303
/// Holds pre-computed comparators for finding RANGE window frame boundaries for all rows in the batch.
@@ -309,7 +308,9 @@ struct WindowRangeComparator {
309308

310309
impl std::fmt::Debug for WindowRangeComparator {
311310
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312-
f.debug_struct("WindowRangeComparator").field("comparators", &self.comparators.len()).finish()
311+
f.debug_struct("WindowRangeComparator")
312+
.field("comparators", &self.comparators.len())
313+
.finish()
313314
}
314315
}
315316

@@ -388,9 +389,9 @@ impl WindowRangeComparator {
388389
}
389390

390391
/// Computes array for a given bound.
391-
/// For PRECEDING with descending=false: bound = value + delta
392-
/// For PRECEDING with descending=true: bound = value - delta
393-
/// For FOLLOWING with descending=false: bound = value + delta
392+
/// For PRECEDING with descending=false: bound = value - delta
393+
/// For PRECEDING with descending=true: bound = value + delta
394+
/// For FOLLOWING with descending=false: bound = value + delta
394395
/// For FOLLOWING with descending=true: bound = value - delta
395396
fn compute_bound_array(
396397
range_column: &ArrayRef,
@@ -405,11 +406,91 @@ impl WindowRangeComparator {
405406
let result = if add {
406407
add_wrapping(range_column, &delta_scalar)
407408
} else {
408-
sub_wrapping(range_column, &delta_scalar)
409+
if let Some(result) = Self::saturating_sub_unsigned_array(range_column, delta)
410+
{
411+
Ok(result)
412+
} else {
413+
sub_wrapping(range_column, &delta_scalar)
414+
}
409415
};
410416
result.map_err(|e| internal_datafusion_err!("Failed to compute bound array: {e}"))
411417
}
412418

419+
fn saturating_sub_unsigned_array(
420+
range_column: &ArrayRef,
421+
delta: &ScalarValue,
422+
) -> Option<ArrayRef> {
423+
match (range_column.data_type(), delta) {
424+
(DataType::UInt8, ScalarValue::UInt8(Some(delta))) => {
425+
let result: PrimitiveArray<UInt8Type> = range_column
426+
.as_primitive::<UInt8Type>()
427+
.unary(|value| value.saturating_sub(*delta));
428+
Some(Arc::new(result))
429+
}
430+
(DataType::UInt16, ScalarValue::UInt16(Some(delta))) => {
431+
let result: PrimitiveArray<UInt16Type> = range_column
432+
.as_primitive::<UInt16Type>()
433+
.unary(|value| value.saturating_sub(*delta));
434+
Some(Arc::new(result))
435+
}
436+
(DataType::UInt16, ScalarValue::UInt8(Some(delta))) => {
437+
let delta = *delta as u16;
438+
let result: PrimitiveArray<UInt16Type> = range_column
439+
.as_primitive::<UInt16Type>()
440+
.unary(|value| value.saturating_sub(delta));
441+
Some(Arc::new(result))
442+
}
443+
(DataType::UInt32, ScalarValue::UInt32(Some(delta))) => {
444+
let result: PrimitiveArray<UInt32Type> = range_column
445+
.as_primitive::<UInt32Type>()
446+
.unary(|value| value.saturating_sub(*delta));
447+
Some(Arc::new(result))
448+
}
449+
(DataType::UInt32, ScalarValue::UInt16(Some(delta))) => {
450+
let delta = *delta as u32;
451+
let result: PrimitiveArray<UInt32Type> = range_column
452+
.as_primitive::<UInt32Type>()
453+
.unary(|value| value.saturating_sub(delta));
454+
Some(Arc::new(result))
455+
}
456+
(DataType::UInt32, ScalarValue::UInt8(Some(delta))) => {
457+
let delta = *delta as u32;
458+
let result: PrimitiveArray<UInt32Type> = range_column
459+
.as_primitive::<UInt32Type>()
460+
.unary(|value| value.saturating_sub(delta));
461+
Some(Arc::new(result))
462+
}
463+
(DataType::UInt64, ScalarValue::UInt64(Some(delta))) => {
464+
let result: PrimitiveArray<UInt64Type> = range_column
465+
.as_primitive::<UInt64Type>()
466+
.unary(|value| value.saturating_sub(*delta));
467+
Some(Arc::new(result))
468+
}
469+
(DataType::UInt64, ScalarValue::UInt32(Some(delta))) => {
470+
let delta = *delta as u64;
471+
let result: PrimitiveArray<UInt64Type> = range_column
472+
.as_primitive::<UInt64Type>()
473+
.unary(|value| value.saturating_sub(delta));
474+
Some(Arc::new(result))
475+
}
476+
(DataType::UInt64, ScalarValue::UInt16(Some(delta))) => {
477+
let delta = *delta as u64;
478+
let result: PrimitiveArray<UInt64Type> = range_column
479+
.as_primitive::<UInt64Type>()
480+
.unary(|value| value.saturating_sub(delta));
481+
Some(Arc::new(result))
482+
}
483+
(DataType::UInt64, ScalarValue::UInt8(Some(delta))) => {
484+
let delta = *delta as u64;
485+
let result: PrimitiveArray<UInt64Type> = range_column
486+
.as_primitive::<UInt64Type>()
487+
.unary(|value| value.saturating_sub(delta));
488+
Some(Arc::new(result))
489+
}
490+
_ => None,
491+
}
492+
}
493+
413494
fn compare(&self, search_idx: usize, current_idx: usize) -> std::cmp::Ordering {
414495
for comparator in &self.comparators {
415496
let cmp = comparator(search_idx, current_idx);
@@ -447,8 +528,16 @@ impl WindowFrameStateRange {
447528
window_frame: &Arc<WindowFrame>,
448529
range_columns: &[ArrayRef],
449530
) -> Result<()> {
450-
self.start_bound_comparator = WindowRangeComparator::try_build(&window_frame.start_bound, range_columns, &self.sort_options)?;
451-
self.end_bound_comparator = WindowRangeComparator::try_build(&window_frame.end_bound, range_columns, &self.sort_options)?;
531+
self.start_bound_comparator = WindowRangeComparator::try_build(
532+
&window_frame.start_bound,
533+
range_columns,
534+
&self.sort_options,
535+
)?;
536+
self.end_bound_comparator = WindowRangeComparator::try_build(
537+
&window_frame.end_bound,
538+
range_columns,
539+
&self.sort_options,
540+
)?;
452541
Ok(())
453542
}
454543

@@ -464,10 +553,8 @@ impl WindowFrameStateRange {
464553
WindowFrameBound::Preceding(_)
465554
| WindowFrameBound::CurrentRow
466555
| WindowFrameBound::Following(_) => {
467-
let comparator = self
468-
.start_bound_comparator
469-
.as_ref()
470-
.ok_or_else(|| {
556+
let comparator =
557+
self.start_bound_comparator.as_ref().ok_or_else(|| {
471558
internal_datafusion_err!("Missing start_bound comparator")
472559
})?;
473560
self.search_index_of_row::<true>(
@@ -484,18 +571,10 @@ impl WindowFrameStateRange {
484571
WindowFrameBound::Preceding(_)
485572
| WindowFrameBound::CurrentRow
486573
| WindowFrameBound::Following(_) => {
487-
let comparator = self
488-
.end_bound_comparator
489-
.as_ref()
490-
.ok_or_else(|| {
491-
internal_datafusion_err!("Missing end_bound comparator")
492-
})?;
493-
self.search_index_of_row::<false>(
494-
comparator,
495-
last_range.end,
496-
length,
497-
idx,
498-
)
574+
let comparator = self.end_bound_comparator.as_ref().ok_or_else(|| {
575+
internal_datafusion_err!("Missing end_bound comparator")
576+
})?;
577+
self.search_index_of_row::<false>(comparator, last_range.end, length, idx)
499578
}
500579
};
501580
Ok(Range { start, end })

0 commit comments

Comments
 (0)