diff --git a/datafusion/core/benches/window_query_sql.rs b/datafusion/core/benches/window_query_sql.rs index 1657cae913fef..1a18ee0b68e69 100644 --- a/datafusion/core/benches/window_query_sql.rs +++ b/datafusion/core/benches/window_query_sql.rs @@ -92,6 +92,20 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function("window order by and range offsets, aggregate functions", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT \ + MAX(f64) OVER (ORDER BY u64_narrow RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING), \ + MIN(f32) OVER (ORDER BY u64_narrow DESC RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING), \ + SUM(u64_narrow) OVER (ORDER BY u64_narrow ASC RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING) \ + FROM t", + ) + }) + }); + c.bench_function("window order by, built-in functions", |b| { b.iter(|| { query( @@ -182,6 +196,23 @@ fn criterion_benchmark(c: &mut Criterion) { }, ); + c.bench_function( + "window partition and order by and range offsets, u64_wide, aggregate functions", + |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT \ + MAX(f64) OVER (PARTITION BY u64_wide ORDER by f64 RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING), \ + MIN(f32) OVER (PARTITION BY u64_wide ORDER by f64 RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING), \ + SUM(u64_narrow) OVER (PARTITION BY u64_wide ORDER by f64 RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING) \ + FROM t", + ) + }) + }, + ); + c.bench_function( "window partition and order by, u64_narrow, aggregate functions", |b| { @@ -199,6 +230,23 @@ fn criterion_benchmark(c: &mut Criterion) { }, ); + c.bench_function( + "window partition and order by and range offsets, u64_narrow, aggregate functions", + |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT \ + MAX(f64) OVER (PARTITION BY u64_narrow ORDER by f64 RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING), \ + MIN(f32) OVER (PARTITION BY u64_narrow ORDER by f64 RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING), \ + SUM(u64_narrow) OVER (PARTITION BY u64_narrow ORDER by f64 RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING) \ + FROM t", + ) + }) + }, + ); + c.bench_function( "window partition and order by, u64_wide, built-in functions", |b| { @@ -232,6 +280,23 @@ fn criterion_benchmark(c: &mut Criterion) { }) }, ); + + c.bench_function( + "window partition and order by and range offsets, u64_wide, aggregate functions", + |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT \ + MAX(f64) OVER (PARTITION BY u64_wide ORDER by f64 RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING), \ + MIN(f32) OVER (PARTITION BY u64_wide ORDER by f64 RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING), \ + SUM(u64_narrow) OVER (PARTITION BY u64_wide ORDER by f64 RANGE BETWEEN 5 PRECEDING AND 5 FOLLOWING) \ + FROM t", + ) + }) + }, + ); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index d7da7a778b011..84c1df3a2faf1 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -22,14 +22,17 @@ use std::{collections::VecDeque, ops::Range, sync::Arc}; use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use arrow::{ - array::ArrayRef, - compute::{SortOptions, concat, concat_batches}, - datatypes::{DataType, SchemaRef}, + array::{ArrayRef, AsArray, PrimitiveArray, make_comparator}, + compute::{ + SortOptions, concat, concat_batches, + kernels::numeric::{add_wrapping, sub_wrapping}, + }, + datatypes::{DataType, SchemaRef, UInt8Type, UInt16Type, UInt32Type, UInt64Type}, record_batch::RecordBatch, }; use datafusion_common::{ Result, ScalarValue, internal_datafusion_err, internal_err, - utils::{compare_rows, get_row_at_idx, search_in_slice}, + utils::{get_row_at_idx, search_in_slice}, }; /// Holds the state of evaluating a window function @@ -152,6 +155,21 @@ impl WindowFrameContext { } } + /// Refreshes any cached frame comparators for the current batch. + /// + /// This is a no-op for `ROWS` and `GROUPS` frames. For `RANGE` frames, + /// callers should invoke this before `calculate_range` when the ORDER BY + /// columns change, such as when evaluating a new batch or partition. + pub fn update_comparators(&mut self, range_columns: &[ArrayRef]) -> Result<()> { + match self { + WindowFrameContext::Range { + window_frame, + state, + } => state.update_comparators(window_frame, range_columns), + _ => Ok(()), + } + } + /// This function calculates beginning/ending indices for the frame of the current row. pub fn calculate_range( &mut self, @@ -170,13 +188,7 @@ impl WindowFrameContext { WindowFrameContext::Range { window_frame, state, - } => state.calculate_range( - window_frame, - last_range, - range_columns, - length, - idx, - ), + } => state.calculate_range(window_frame, last_range, length, idx), // Sort options is not used in GROUPS mode calculations as the // inequality of two rows indicates a group change, and ordering // or position of NULLs do not impact inequality. @@ -291,155 +303,305 @@ impl PartitionBatchState { } } +type SharedDynComparator = Arc std::cmp::Ordering + Send + Sync>; + +/// Holds pre-computed comparators for finding RANGE window frame boundaries for all rows in the batch. +#[derive(Clone)] +struct WindowRangeComparator { + comparators: Vec, +} + +impl std::fmt::Debug for WindowRangeComparator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WindowRangeComparator") + .field("comparators", &self.comparators.len()) + .finish() + } +} + +impl WindowRangeComparator { + fn try_build( + bound: &WindowFrameBound, + range_columns: &[ArrayRef], + sort_options: &[SortOptions], + ) -> Result> { + let sort_descending = sort_options.first().map(|o| o.descending).unwrap_or(false); + match bound { + WindowFrameBound::Preceding(delta) if delta.is_null() => { + // UNBOUNDED PRECEDING + Ok(None) + } + WindowFrameBound::Following(delta) if delta.is_null() => { + // UNBOUNDED FOLLOWING + Ok(None) + } + WindowFrameBound::Preceding(delta) => { + let comparators = Self::build_comparator( + range_columns, + Some(delta), + true, + sort_descending, + sort_options, + )?; + Ok(Some(comparators)) + } + WindowFrameBound::Following(delta) => { + let comparators = Self::build_comparator( + range_columns, + Some(delta), + false, + sort_descending, + sort_options, + )?; + Ok(Some(comparators)) + } + WindowFrameBound::CurrentRow => { + let comparators = Self::build_comparator( + range_columns, + None, + false, + sort_descending, + sort_options, + )?; + Ok(Some(comparators)) + } + } + } + + fn build_comparator( + range_columns: &[ArrayRef], + delta: Option<&ScalarValue>, + preceding: bool, + sort_descending: bool, + sort_options: &[SortOptions], + ) -> Result { + let mut comparators: Vec = + Vec::with_capacity(range_columns.len()); + for (col, opt) in range_columns.iter().zip(sort_options.iter()) { + let cmp = match delta { + Some(d) => { + let bound_col = + Self::compute_bound_array(col, d, preceding, sort_descending)?; + make_comparator(col, &bound_col, *opt) + } + None => make_comparator(col, col, *opt), + } + .map_err(|e| internal_datafusion_err!("Failed to create comparator: {e}"))?; + + comparators.push(Arc::from(cmp)); + } + Ok(WindowRangeComparator { comparators }) + } + + /// Computes array for a given bound. + /// For PRECEDING with descending=false: bound = value - delta + /// For PRECEDING with descending=true: bound = value + delta + /// For FOLLOWING with descending=false: bound = value + delta + /// For FOLLOWING with descending=true: bound = value - delta + fn compute_bound_array( + range_column: &ArrayRef, + delta: &ScalarValue, + preceding: bool, + sort_descending: bool, + ) -> Result { + let delta_scalar = delta.to_scalar()?; + let add = preceding == sort_descending; + + // TODO: Handle overflows. + let result = if add { + add_wrapping(range_column, &delta_scalar) + } else { + if let Some(result) = Self::saturating_sub_unsigned_array(range_column, delta) + { + Ok(result) + } else { + sub_wrapping(range_column, &delta_scalar) + } + }; + result.map_err(|e| internal_datafusion_err!("Failed to compute bound array: {e}")) + } + + fn saturating_sub_unsigned_array( + range_column: &ArrayRef, + delta: &ScalarValue, + ) -> Option { + match (range_column.data_type(), delta) { + (DataType::UInt8, ScalarValue::UInt8(Some(delta))) => { + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(*delta)); + Some(Arc::new(result)) + } + (DataType::UInt16, ScalarValue::UInt16(Some(delta))) => { + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(*delta)); + Some(Arc::new(result)) + } + (DataType::UInt16, ScalarValue::UInt8(Some(delta))) => { + let delta = *delta as u16; + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(delta)); + Some(Arc::new(result)) + } + (DataType::UInt32, ScalarValue::UInt32(Some(delta))) => { + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(*delta)); + Some(Arc::new(result)) + } + (DataType::UInt32, ScalarValue::UInt16(Some(delta))) => { + let delta = *delta as u32; + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(delta)); + Some(Arc::new(result)) + } + (DataType::UInt32, ScalarValue::UInt8(Some(delta))) => { + let delta = *delta as u32; + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(delta)); + Some(Arc::new(result)) + } + (DataType::UInt64, ScalarValue::UInt64(Some(delta))) => { + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(*delta)); + Some(Arc::new(result)) + } + (DataType::UInt64, ScalarValue::UInt32(Some(delta))) => { + let delta = *delta as u64; + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(delta)); + Some(Arc::new(result)) + } + (DataType::UInt64, ScalarValue::UInt16(Some(delta))) => { + let delta = *delta as u64; + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(delta)); + Some(Arc::new(result)) + } + (DataType::UInt64, ScalarValue::UInt8(Some(delta))) => { + let delta = *delta as u64; + let result: PrimitiveArray = range_column + .as_primitive::() + .unary(|value| value.saturating_sub(delta)); + Some(Arc::new(result)) + } + _ => None, + } + } + + fn compare(&self, search_idx: usize, current_idx: usize) -> std::cmp::Ordering { + for comparator in &self.comparators { + let cmp = comparator(search_idx, current_idx); + if cmp != std::cmp::Ordering::Equal { + return cmp; + } + } + std::cmp::Ordering::Equal + } +} + /// This structure encapsulates all the state information we require as we scan /// ranges of data while processing RANGE frames. /// Attribute `sort_options` stores the column ordering specified by the ORDER /// BY clause. This information is used to calculate the range. +/// Attributes `start_bound_comparator` and `end_bound_comparator` store the cached comparators for calculating ranges. #[derive(Debug, Default, Clone)] pub struct WindowFrameStateRange { sort_options: Vec, + start_bound_comparator: Option, + end_bound_comparator: Option, } impl WindowFrameStateRange { /// Create a new object to store the search state. fn new(sort_options: Vec) -> Self { - Self { sort_options } + Self { + sort_options, + start_bound_comparator: None, + end_bound_comparator: None, + } } - /// This function calculates beginning/ending indices for the frame of the current row. - // Argument `last_range` stores the resulting indices from the previous search. Since the indices only - // advance forward, we start from `last_range` subsequently. Thus, the overall - // time complexity of linear search amortizes to O(n) where n denotes the total - // row count. - fn calculate_range( + fn update_comparators( &mut self, window_frame: &Arc, - last_range: &Range, range_columns: &[ArrayRef], + ) -> Result<()> { + self.start_bound_comparator = WindowRangeComparator::try_build( + &window_frame.start_bound, + range_columns, + &self.sort_options, + )?; + self.end_bound_comparator = WindowRangeComparator::try_build( + &window_frame.end_bound, + range_columns, + &self.sort_options, + )?; + Ok(()) + } + + fn calculate_range( + &self, + window_frame: &Arc, + last_range: &Range, length: usize, idx: usize, ) -> Result> { let start = match window_frame.start_bound { - WindowFrameBound::Preceding(ref n) => { - if n.is_null() { - // UNBOUNDED PRECEDING - 0 - } else { - self.calculate_index_of_row::( - range_columns, - last_range, - idx, - Some(n), - length, - )? - } - } - WindowFrameBound::CurrentRow => self.calculate_index_of_row::( - range_columns, - last_range, - idx, - None, - length, - )?, - WindowFrameBound::Following(ref n) => self - .calculate_index_of_row::( - range_columns, - last_range, - idx, - Some(n), + WindowFrameBound::Preceding(ref n) if n.is_null() => 0, + WindowFrameBound::Preceding(_) + | WindowFrameBound::CurrentRow + | WindowFrameBound::Following(_) => { + let comparator = + self.start_bound_comparator.as_ref().ok_or_else(|| { + internal_datafusion_err!("Missing start_bound comparator") + })?; + self.search_index_of_row::( + comparator, + last_range.start, length, - )?, + idx, + ) + } }; + let end = match window_frame.end_bound { - WindowFrameBound::Preceding(ref n) => self - .calculate_index_of_row::( - range_columns, - last_range, - idx, - Some(n), - length, - )?, - WindowFrameBound::CurrentRow => self.calculate_index_of_row::( - range_columns, - last_range, - idx, - None, - length, - )?, - WindowFrameBound::Following(ref n) => { - if n.is_null() { - // UNBOUNDED FOLLOWING - length - } else { - self.calculate_index_of_row::( - range_columns, - last_range, - idx, - Some(n), - length, - )? - } + WindowFrameBound::Following(ref n) if n.is_null() => length, + WindowFrameBound::Preceding(_) + | WindowFrameBound::CurrentRow + | WindowFrameBound::Following(_) => { + let comparator = self.end_bound_comparator.as_ref().ok_or_else(|| { + internal_datafusion_err!("Missing end_bound comparator") + })?; + self.search_index_of_row::(comparator, last_range.end, length, idx) } }; Ok(Range { start, end }) } - /// This function does the heavy lifting when finding range boundaries. It is meant to be - /// called twice, in succession, to get window frame start and end indices (with `SIDE` - /// supplied as true and false, respectively). - fn calculate_index_of_row( - &mut self, - range_columns: &[ArrayRef], - last_range: &Range, - idx: usize, - delta: Option<&ScalarValue>, + fn search_index_of_row( + &self, + comparator: &WindowRangeComparator, + mut search_start: usize, length: usize, - ) -> Result { - let current_row_values = get_row_at_idx(range_columns, idx)?; - let end_range = if let Some(delta) = delta { - let is_descending: bool = self - .sort_options - .first() - .ok_or_else(|| { - internal_datafusion_err!( - "Sort options unexpectedly absent in a window frame" - ) - })? - .descending; - - current_row_values - .iter() - .map(|value| { - if value.is_null() { - return Ok(value.clone()); - } - if SEARCH_SIDE == is_descending { - // TODO: Handle positive overflows. - value.add(delta) - } else if value.is_unsigned() && value < delta { - // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. - // If we decide to implement a "default" construction mechanism for ScalarValue, - // change the following statement to use that. - value.sub(value) - } else { - // TODO: Handle negative overflows. - value.sub(delta) - } - }) - .collect::>>()? - } else { - current_row_values - }; - let search_start = if SIDE { - last_range.start - } else { - last_range.end - }; - let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { - let cmp = compare_rows(current, target, &self.sort_options)?; - Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) - }; - search_in_slice(range_columns, &end_range, compare_fn, search_start, length) + current_idx: usize, + ) -> usize { + while search_start < length { + let cmp = comparator.compare(search_start, current_idx); + let stop = if SIDE { !cmp.is_lt() } else { !cmp.is_le() }; + if stop { + break; + } + search_start += 1; + } + search_start } } @@ -719,6 +881,7 @@ mod tests { let (range_columns, _) = get_test_data(); let n_row = range_columns[0].len(); let mut last_range = Range { start: 0, end: 0 }; + window_frame_context.update_comparators(&range_columns)?; for (idx, expected_range) in expected_results.into_iter().enumerate() { let range = window_frame_context.calculate_range( &range_columns, diff --git a/datafusion/physical-expr/src/window/standard.rs b/datafusion/physical-expr/src/window/standard.rs index f8d92d5de4ad5..a289c588719b7 100644 --- a/datafusion/physical-expr/src/window/standard.rs +++ b/datafusion/physical-expr/src/window/standard.rs @@ -128,6 +128,9 @@ impl WindowExpr for StandardWindowExpr { let mut window_frame_ctx = WindowFrameContext::new(Arc::clone(&self.window_frame), sort_options); let mut last_range = Range { start: 0, end: 0 }; + + window_frame_ctx.update_comparators(order_bys_ref)?; + // We iterate on each row to calculate window frame range and and window function result for idx in 0..num_rows { let range = window_frame_ctx.calculate_range( @@ -203,6 +206,19 @@ impl WindowExpr for StandardWindowExpr { } else { evaluator.is_causal() }; + + if evaluator.uses_window_frame() { + state + .window_frame_ctx + .get_or_insert_with(|| { + WindowFrameContext::new( + Arc::clone(&self.window_frame), + sort_options.clone(), + ) + }) + .update_comparators(order_bys_ref)?; + } + for idx in state.last_calculated_index..num_rows { let frame_range = if evaluator.uses_window_frame() { state diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 0f0ec647a50ae..b0b66692640d9 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -335,6 +335,9 @@ pub trait AggregateWindowExpr: WindowExpr { let length = values[0].len(); let mut row_wise_results: Vec = vec![]; let is_causal = self.get_window_frame().is_causal(); + + window_frame_ctx.update_comparators(&order_bys)?; + while idx < length { // Start search from the last_range. This squeezes searched range. let cur_range =