@@ -22,12 +22,12 @@ use std::{collections::VecDeque, ops::Range, sync::Arc};
2222use crate :: { WindowFrame , WindowFrameBound , WindowFrameUnits } ;
2323
2424use arrow:: {
25- array:: { ArrayRef , make_comparator} ,
25+ array:: { ArrayRef , AsArray , PrimitiveArray , make_comparator} ,
2626 compute:: {
2727 SortOptions , concat, concat_batches,
2828 kernels:: numeric:: { add_wrapping, sub_wrapping} ,
2929 } ,
30- datatypes:: { DataType , SchemaRef } ,
30+ datatypes:: { DataType , SchemaRef , UInt8Type , UInt16Type , UInt32Type , UInt64Type } ,
3131 record_batch:: RecordBatch ,
3232} ;
3333use datafusion_common:: {
@@ -298,7 +298,6 @@ impl PartitionBatchState {
298298 }
299299}
300300
301-
302301type SharedDynComparator = Arc < dyn Fn ( usize , usize ) -> std:: cmp:: Ordering + Send + Sync > ;
303302
304303/// Holds pre-computed comparators for finding RANGE window frame boundaries for all rows in the batch.
@@ -309,7 +308,9 @@ struct WindowRangeComparator {
309308
310309impl std:: fmt:: Debug for WindowRangeComparator {
311310 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
312- f. debug_struct ( "WindowRangeComparator" ) . field ( "comparators" , & self . comparators . len ( ) ) . finish ( )
311+ f. debug_struct ( "WindowRangeComparator" )
312+ . field ( "comparators" , & self . comparators . len ( ) )
313+ . finish ( )
313314 }
314315}
315316
@@ -388,9 +389,9 @@ impl WindowRangeComparator {
388389 }
389390
390391 /// Computes array for a given bound.
391- /// For PRECEDING with descending=false: bound = value + delta
392- /// For PRECEDING with descending=true: bound = value - delta
393- /// For FOLLOWING with descending=false: bound = value + delta
392+ /// For PRECEDING with descending=false: bound = value - delta
393+ /// For PRECEDING with descending=true: bound = value + delta
394+ /// For FOLLOWING with descending=false: bound = value + delta
394395 /// For FOLLOWING with descending=true: bound = value - delta
395396 fn compute_bound_array (
396397 range_column : & ArrayRef ,
@@ -405,11 +406,91 @@ impl WindowRangeComparator {
405406 let result = if add {
406407 add_wrapping ( range_column, & delta_scalar)
407408 } else {
408- sub_wrapping ( range_column, & delta_scalar)
409+ if let Some ( result) = Self :: saturating_sub_unsigned_array ( range_column, delta)
410+ {
411+ Ok ( result)
412+ } else {
413+ sub_wrapping ( range_column, & delta_scalar)
414+ }
409415 } ;
410416 result. map_err ( |e| internal_datafusion_err ! ( "Failed to compute bound array: {e}" ) )
411417 }
412418
419+ fn saturating_sub_unsigned_array (
420+ range_column : & ArrayRef ,
421+ delta : & ScalarValue ,
422+ ) -> Option < ArrayRef > {
423+ match ( range_column. data_type ( ) , delta) {
424+ ( DataType :: UInt8 , ScalarValue :: UInt8 ( Some ( delta) ) ) => {
425+ let result: PrimitiveArray < UInt8Type > = range_column
426+ . as_primitive :: < UInt8Type > ( )
427+ . unary ( |value| value. saturating_sub ( * delta) ) ;
428+ Some ( Arc :: new ( result) )
429+ }
430+ ( DataType :: UInt16 , ScalarValue :: UInt16 ( Some ( delta) ) ) => {
431+ let result: PrimitiveArray < UInt16Type > = range_column
432+ . as_primitive :: < UInt16Type > ( )
433+ . unary ( |value| value. saturating_sub ( * delta) ) ;
434+ Some ( Arc :: new ( result) )
435+ }
436+ ( DataType :: UInt16 , ScalarValue :: UInt8 ( Some ( delta) ) ) => {
437+ let delta = * delta as u16 ;
438+ let result: PrimitiveArray < UInt16Type > = range_column
439+ . as_primitive :: < UInt16Type > ( )
440+ . unary ( |value| value. saturating_sub ( delta) ) ;
441+ Some ( Arc :: new ( result) )
442+ }
443+ ( DataType :: UInt32 , ScalarValue :: UInt32 ( Some ( delta) ) ) => {
444+ let result: PrimitiveArray < UInt32Type > = range_column
445+ . as_primitive :: < UInt32Type > ( )
446+ . unary ( |value| value. saturating_sub ( * delta) ) ;
447+ Some ( Arc :: new ( result) )
448+ }
449+ ( DataType :: UInt32 , ScalarValue :: UInt16 ( Some ( delta) ) ) => {
450+ let delta = * delta as u32 ;
451+ let result: PrimitiveArray < UInt32Type > = range_column
452+ . as_primitive :: < UInt32Type > ( )
453+ . unary ( |value| value. saturating_sub ( delta) ) ;
454+ Some ( Arc :: new ( result) )
455+ }
456+ ( DataType :: UInt32 , ScalarValue :: UInt8 ( Some ( delta) ) ) => {
457+ let delta = * delta as u32 ;
458+ let result: PrimitiveArray < UInt32Type > = range_column
459+ . as_primitive :: < UInt32Type > ( )
460+ . unary ( |value| value. saturating_sub ( delta) ) ;
461+ Some ( Arc :: new ( result) )
462+ }
463+ ( DataType :: UInt64 , ScalarValue :: UInt64 ( Some ( delta) ) ) => {
464+ let result: PrimitiveArray < UInt64Type > = range_column
465+ . as_primitive :: < UInt64Type > ( )
466+ . unary ( |value| value. saturating_sub ( * delta) ) ;
467+ Some ( Arc :: new ( result) )
468+ }
469+ ( DataType :: UInt64 , ScalarValue :: UInt32 ( Some ( delta) ) ) => {
470+ let delta = * delta as u64 ;
471+ let result: PrimitiveArray < UInt64Type > = range_column
472+ . as_primitive :: < UInt64Type > ( )
473+ . unary ( |value| value. saturating_sub ( delta) ) ;
474+ Some ( Arc :: new ( result) )
475+ }
476+ ( DataType :: UInt64 , ScalarValue :: UInt16 ( Some ( delta) ) ) => {
477+ let delta = * delta as u64 ;
478+ let result: PrimitiveArray < UInt64Type > = range_column
479+ . as_primitive :: < UInt64Type > ( )
480+ . unary ( |value| value. saturating_sub ( delta) ) ;
481+ Some ( Arc :: new ( result) )
482+ }
483+ ( DataType :: UInt64 , ScalarValue :: UInt8 ( Some ( delta) ) ) => {
484+ let delta = * delta as u64 ;
485+ let result: PrimitiveArray < UInt64Type > = range_column
486+ . as_primitive :: < UInt64Type > ( )
487+ . unary ( |value| value. saturating_sub ( delta) ) ;
488+ Some ( Arc :: new ( result) )
489+ }
490+ _ => None ,
491+ }
492+ }
493+
413494 fn compare ( & self , search_idx : usize , current_idx : usize ) -> std:: cmp:: Ordering {
414495 for comparator in & self . comparators {
415496 let cmp = comparator ( search_idx, current_idx) ;
@@ -447,8 +528,16 @@ impl WindowFrameStateRange {
447528 window_frame : & Arc < WindowFrame > ,
448529 range_columns : & [ ArrayRef ] ,
449530 ) -> Result < ( ) > {
450- self . start_bound_comparator = WindowRangeComparator :: try_build ( & window_frame. start_bound , range_columns, & self . sort_options ) ?;
451- self . end_bound_comparator = WindowRangeComparator :: try_build ( & window_frame. end_bound , range_columns, & self . sort_options ) ?;
531+ self . start_bound_comparator = WindowRangeComparator :: try_build (
532+ & window_frame. start_bound ,
533+ range_columns,
534+ & self . sort_options ,
535+ ) ?;
536+ self . end_bound_comparator = WindowRangeComparator :: try_build (
537+ & window_frame. end_bound ,
538+ range_columns,
539+ & self . sort_options ,
540+ ) ?;
452541 Ok ( ( ) )
453542 }
454543
@@ -464,10 +553,8 @@ impl WindowFrameStateRange {
464553 WindowFrameBound :: Preceding ( _)
465554 | WindowFrameBound :: CurrentRow
466555 | WindowFrameBound :: Following ( _) => {
467- let comparator = self
468- . start_bound_comparator
469- . as_ref ( )
470- . ok_or_else ( || {
556+ let comparator =
557+ self . start_bound_comparator . as_ref ( ) . ok_or_else ( || {
471558 internal_datafusion_err ! ( "Missing start_bound comparator" )
472559 } ) ?;
473560 self . search_index_of_row :: < true > (
@@ -484,18 +571,10 @@ impl WindowFrameStateRange {
484571 WindowFrameBound :: Preceding ( _)
485572 | WindowFrameBound :: CurrentRow
486573 | WindowFrameBound :: Following ( _) => {
487- let comparator = self
488- . end_bound_comparator
489- . as_ref ( )
490- . ok_or_else ( || {
491- internal_datafusion_err ! ( "Missing end_bound comparator" )
492- } ) ?;
493- self . search_index_of_row :: < false > (
494- comparator,
495- last_range. end ,
496- length,
497- idx,
498- )
574+ let comparator = self . end_bound_comparator . as_ref ( ) . ok_or_else ( || {
575+ internal_datafusion_err ! ( "Missing end_bound comparator" )
576+ } ) ?;
577+ self . search_index_of_row :: < false > ( comparator, last_range. end , length, idx)
499578 }
500579 } ;
501580 Ok ( Range { start, end } )
0 commit comments