Skip to content

Commit f9cfb0d

Browse files
committed
refactor
1 parent b0c65c1 commit f9cfb0d

3 files changed

Lines changed: 50 additions & 93 deletions

File tree

datafusion/expr/src/window_state.rs

Lines changed: 47 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,12 @@ impl WindowFrameContext {
155155
}
156156
}
157157

158-
pub fn calculate_bounds(&mut self, range_columns: &[ArrayRef]) -> Result<()> {
158+
pub fn update_comparators(&mut self, range_columns: &[ArrayRef]) -> Result<()> {
159159
match self {
160160
WindowFrameContext::Range {
161161
window_frame,
162162
state,
163-
} => state.calculate_bounds(window_frame, range_columns),
163+
} => state.update_comparators(window_frame, range_columns),
164164
_ => Ok(()),
165165
}
166166
}
@@ -298,62 +298,28 @@ impl PartitionBatchState {
298298
}
299299
}
300300

301+
301302
type SharedDynComparator = Arc<dyn Fn(usize, usize) -> std::cmp::Ordering + Send + Sync>;
302303

303304
/// Holds pre-computed comparators for finding RANGE window frame boundaries for all rows in the batch.
304305
#[derive(Clone)]
305-
pub struct WindowBoundStateRange {
306-
start_bound_comparators: Option<Vec<SharedDynComparator>>,
307-
end_bound_comparators: Option<Vec<SharedDynComparator>>,
306+
struct WindowRangeComparator {
307+
comparators: Vec<SharedDynComparator>,
308308
}
309309

310-
impl std::fmt::Debug for WindowBoundStateRange {
310+
impl std::fmt::Debug for WindowRangeComparator {
311311
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312-
f.debug_struct("WindowBoundStateRange")
313-
.field(
314-
"start_bound_comparators",
315-
&self.start_bound_comparators.as_ref().map(|c| c.len()),
316-
)
317-
.field(
318-
"end_bound_comparators",
319-
&self.end_bound_comparators.as_ref().map(|c| c.len()),
320-
)
321-
.finish()
312+
f.debug_struct("WindowRangeComparator").field("comparators", &self.comparators.len()).finish()
322313
}
323314
}
324315

325-
impl WindowBoundStateRange {
326-
pub fn try_new(
327-
window_frame: &Arc<WindowFrame>,
328-
range_columns: &[ArrayRef],
329-
sort_options: &[SortOptions],
330-
) -> Result<Self> {
331-
let sort_descending = sort_options.first().map(|o| o.descending).unwrap_or(false);
332-
333-
let start_bound_comparators = Self::build_bound_comparators(
334-
&window_frame.start_bound,
335-
range_columns,
336-
sort_options,
337-
sort_descending,
338-
)?;
339-
let end_bound_comparators = Self::build_bound_comparators(
340-
&window_frame.end_bound,
341-
range_columns,
342-
sort_options,
343-
sort_descending,
344-
)?;
345-
Ok(Self {
346-
start_bound_comparators,
347-
end_bound_comparators,
348-
})
349-
}
350-
351-
fn build_bound_comparators(
316+
impl WindowRangeComparator {
317+
fn try_build(
352318
bound: &WindowFrameBound,
353319
range_columns: &[ArrayRef],
354320
sort_options: &[SortOptions],
355-
sort_descending: bool,
356-
) -> Result<Option<Vec<SharedDynComparator>>> {
321+
) -> Result<Option<Self>> {
322+
let sort_descending = sort_options.first().map(|o| o.descending).unwrap_or(false);
357323
match bound {
358324
WindowFrameBound::Preceding(delta) if delta.is_null() => {
359325
// UNBOUNDED PRECEDING
@@ -364,7 +330,7 @@ impl WindowBoundStateRange {
364330
Ok(None)
365331
}
366332
WindowFrameBound::Preceding(delta) => {
367-
let comparators = Self::build_comparators(
333+
let comparators = Self::build_comparator(
368334
range_columns,
369335
Some(delta),
370336
true,
@@ -374,7 +340,7 @@ impl WindowBoundStateRange {
374340
Ok(Some(comparators))
375341
}
376342
WindowFrameBound::Following(delta) => {
377-
let comparators = Self::build_comparators(
343+
let comparators = Self::build_comparator(
378344
range_columns,
379345
Some(delta),
380346
false,
@@ -384,7 +350,7 @@ impl WindowBoundStateRange {
384350
Ok(Some(comparators))
385351
}
386352
WindowFrameBound::CurrentRow => {
387-
let comparators = Self::build_comparators(
353+
let comparators = Self::build_comparator(
388354
range_columns,
389355
None,
390356
false,
@@ -396,13 +362,13 @@ impl WindowBoundStateRange {
396362
}
397363
}
398364

399-
fn build_comparators(
365+
fn build_comparator(
400366
range_columns: &[ArrayRef],
401367
delta: Option<&ScalarValue>,
402368
preceding: bool,
403369
sort_descending: bool,
404370
sort_options: &[SortOptions],
405-
) -> Result<Vec<SharedDynComparator>> {
371+
) -> Result<Self> {
406372
let mut comparators: Vec<SharedDynComparator> =
407373
Vec::with_capacity(range_columns.len());
408374
for (col, opt) in range_columns.iter().zip(sort_options.iter()) {
@@ -418,7 +384,7 @@ impl WindowBoundStateRange {
418384

419385
comparators.push(Arc::from(cmp));
420386
}
421-
Ok(comparators)
387+
Ok(WindowRangeComparator { comparators })
422388
}
423389

424390
/// Computes array for a given bound.
@@ -443,6 +409,16 @@ impl WindowBoundStateRange {
443409
};
444410
result.map_err(|e| internal_datafusion_err!("Failed to compute bound array: {e}"))
445411
}
412+
413+
fn compare(&self, search_idx: usize, current_idx: usize) -> std::cmp::Ordering {
414+
for comparator in &self.comparators {
415+
let cmp = comparator(search_idx, current_idx);
416+
if cmp != std::cmp::Ordering::Equal {
417+
return cmp;
418+
}
419+
}
420+
std::cmp::Ordering::Equal
421+
}
446422
}
447423

448424
/// This structure encapsulates all the state information we require as we scan
@@ -452,29 +428,27 @@ impl WindowBoundStateRange {
452428
#[derive(Debug, Default, Clone)]
453429
pub struct WindowFrameStateRange {
454430
sort_options: Vec<SortOptions>,
455-
bound_state: Option<WindowBoundStateRange>,
431+
start_bound_comparator: Option<WindowRangeComparator>,
432+
end_bound_comparator: Option<WindowRangeComparator>,
456433
}
457434

458435
impl WindowFrameStateRange {
459436
/// Create a new object to store the search state.
460437
fn new(sort_options: Vec<SortOptions>) -> Self {
461438
Self {
462439
sort_options,
463-
bound_state: None,
440+
start_bound_comparator: None,
441+
end_bound_comparator: None,
464442
}
465443
}
466444

467-
fn calculate_bounds(
445+
fn update_comparators(
468446
&mut self,
469447
window_frame: &Arc<WindowFrame>,
470448
range_columns: &[ArrayRef],
471449
) -> Result<()> {
472-
let bound_state = WindowBoundStateRange::try_new(
473-
window_frame,
474-
range_columns,
475-
&self.sort_options,
476-
)?;
477-
self.bound_state = Some(bound_state);
450+
self.start_bound_comparator = WindowRangeComparator::try_build(&window_frame.start_bound, range_columns, &self.sort_options)?;
451+
self.end_bound_comparator = WindowRangeComparator::try_build(&window_frame.end_bound, range_columns, &self.sort_options)?;
478452
Ok(())
479453
}
480454

@@ -485,23 +459,19 @@ impl WindowFrameStateRange {
485459
length: usize,
486460
idx: usize,
487461
) -> Result<Range<usize>> {
488-
let bound_state = self.bound_state.as_ref().ok_or_else(|| {
489-
internal_datafusion_err!("Missing precalculated WindowBoundStateRange")
490-
})?;
491-
492462
let start = match window_frame.start_bound {
493463
WindowFrameBound::Preceding(ref n) if n.is_null() => 0,
494464
WindowFrameBound::Preceding(_)
495465
| WindowFrameBound::CurrentRow
496466
| WindowFrameBound::Following(_) => {
497-
let comparators = bound_state
498-
.start_bound_comparators
467+
let comparator = self
468+
.start_bound_comparator
499469
.as_ref()
500470
.ok_or_else(|| {
501-
internal_datafusion_err!("Missing start_bound comparators")
471+
internal_datafusion_err!("Missing start_bound comparator")
502472
})?;
503473
self.search_index_of_row::<true>(
504-
comparators,
474+
comparator,
505475
last_range.start,
506476
length,
507477
idx,
@@ -514,12 +484,14 @@ impl WindowFrameStateRange {
514484
WindowFrameBound::Preceding(_)
515485
| WindowFrameBound::CurrentRow
516486
| WindowFrameBound::Following(_) => {
517-
let comparators =
518-
bound_state.end_bound_comparators.as_ref().ok_or_else(|| {
519-
internal_datafusion_err!("Missing end_bound comparators")
487+
let comparator = self
488+
.end_bound_comparator
489+
.as_ref()
490+
.ok_or_else(|| {
491+
internal_datafusion_err!("Missing end_bound comparator")
520492
})?;
521493
self.search_index_of_row::<false>(
522-
comparators,
494+
comparator,
523495
last_range.end,
524496
length,
525497
idx,
@@ -531,13 +503,13 @@ impl WindowFrameStateRange {
531503

532504
fn search_index_of_row<const SIDE: bool>(
533505
&self,
534-
comparators: &[SharedDynComparator],
506+
comparator: &WindowRangeComparator,
535507
mut search_start: usize,
536508
length: usize,
537509
current_idx: usize,
538510
) -> usize {
539511
while search_start < length {
540-
let cmp = self.compare_indexes(comparators, search_start, current_idx);
512+
let cmp = comparator.compare(search_start, current_idx);
541513
let stop = if SIDE { !cmp.is_lt() } else { !cmp.is_le() };
542514
if stop {
543515
break;
@@ -546,21 +518,6 @@ impl WindowFrameStateRange {
546518
}
547519
search_start
548520
}
549-
550-
fn compare_indexes(
551-
&self,
552-
comparators: &[SharedDynComparator],
553-
search_idx: usize,
554-
current_idx: usize,
555-
) -> std::cmp::Ordering {
556-
for comparator in comparators {
557-
let cmp = comparator(search_idx, current_idx);
558-
if cmp != std::cmp::Ordering::Equal {
559-
return cmp;
560-
}
561-
}
562-
std::cmp::Ordering::Equal
563-
}
564521
}
565522

566523
// In GROUPS mode, rows with duplicate sorting values are grouped together.
@@ -839,7 +796,7 @@ mod tests {
839796
let (range_columns, _) = get_test_data();
840797
let n_row = range_columns[0].len();
841798
let mut last_range = Range { start: 0, end: 0 };
842-
window_frame_context.calculate_bounds(&range_columns)?;
799+
window_frame_context.update_comparators(&range_columns)?;
843800
for (idx, expected_range) in expected_results.into_iter().enumerate() {
844801
let range = window_frame_context.calculate_range(
845802
&range_columns,

datafusion/physical-expr/src/window/standard.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl WindowExpr for StandardWindowExpr {
129129
WindowFrameContext::new(Arc::clone(&self.window_frame), sort_options);
130130
let mut last_range = Range { start: 0, end: 0 };
131131

132-
window_frame_ctx.calculate_bounds(order_bys_ref)?;
132+
window_frame_ctx.update_comparators(order_bys_ref)?;
133133

134134
// We iterate on each row to calculate window frame range and window function result
135135
for idx in 0..num_rows {
@@ -216,7 +216,7 @@ impl WindowExpr for StandardWindowExpr {
216216
sort_options.clone(),
217217
)
218218
})
219-
.calculate_bounds(order_bys_ref)?;
219+
.update_comparators(order_bys_ref)?;
220220
}
221221

222222
for idx in state.last_calculated_index..num_rows {

datafusion/physical-expr/src/window/window_expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ pub trait AggregateWindowExpr: WindowExpr {
336336
let mut row_wise_results: Vec<ScalarValue> = vec![];
337337
let is_causal = self.get_window_frame().is_causal();
338338

339-
window_frame_ctx.calculate_bounds(&order_bys)?;
339+
window_frame_ctx.update_comparators(&order_bys)?;
340340

341341
while idx < length {
342342
// Start search from the last_range. This squeezes searched range.

0 commit comments

Comments
 (0)