From 4feba229a06208b62b5f9319bd59b22f0cb7ebd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 3 Apr 2026 12:34:31 +0200 Subject: [PATCH] Add batch pass-through optimization to SortPreservingMergeExec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the loser tree winner's entire remaining batch is strictly less than every other stream's current value, skip per-row loser-tree comparisons and emit the batch directly. Two fast paths: - Zero-copy: when in_progress buffer is empty and the full batch qualifies, slice and return the RecordBatch without interleave - Bulk-push: otherwise append all qualifying rows at once, avoiding O(remaining × log K) loser-tree work The runner-up is found by walking the winner's loser-tree path (O(log K)), and the check is only performed at the start of each new batch to amortise cost. Co-Authored-By: Claude Opus 4.6 (1M context) --- datafusion/physical-plan/src/sorts/builder.rs | 28 +++ datafusion/physical-plan/src/sorts/cursor.rs | 28 +++ datafusion/physical-plan/src/sorts/merge.rs | 108 +++++++++++ .../src/sorts/sort_preserving_merge.rs | 168 ++++++++++++++++++ 4 files changed, 332 insertions(+) diff --git a/datafusion/physical-plan/src/sorts/builder.rs b/datafusion/physical-plan/src/sorts/builder.rs index 73386212a2a90..c004ed1f3211b 100644 --- a/datafusion/physical-plan/src/sorts/builder.rs +++ b/datafusion/physical-plan/src/sorts/builder.rs @@ -116,6 +116,34 @@ impl BatchBuilder { self.indices.push((cursor.batch_idx, row_idx)); } + /// Append `count` consecutive rows from `stream_idx` + pub fn push_rows(&mut self, stream_idx: usize, count: usize) { + let cursor = &mut self.cursors[stream_idx]; + let batch_idx = cursor.batch_idx; + let start_row = cursor.row_idx; + self.indices + .extend((0..count).map(|i| (batch_idx, start_row + i))); + cursor.row_idx += count; + } + + /// Slice the current batch for `stream_idx` starting at its cursor + /// position, returning 
`num_rows` rows as a zero-copy [`RecordBatch`]. + /// + /// Advances the builder's cursor but does **not** touch `self.indices`, + /// so the caller must not also call `push_row`/`push_rows` for these + /// rows. + pub fn take_batch_slice( + &mut self, + stream_idx: usize, + num_rows: usize, + ) -> RecordBatch { + let cursor = &mut self.cursors[stream_idx]; + let (_, batch) = &self.batches[cursor.batch_idx]; + let sliced = batch.slice(cursor.row_idx, num_rows); + cursor.row_idx += num_rows; + sliced + } + /// Returns the number of in-progress rows in this [`BatchBuilder`] pub fn len(&self) -> usize { self.indices.len() diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index 288ec4cee1594..efe3915e87b24 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -93,6 +93,16 @@ impl Cursor { self.offset == self.values.len() } + /// Returns true if the cursor is at the start (offset 0) + pub fn is_at_start(&self) -> bool { + self.offset == 0 + } + + /// Returns the number of remaining rows (including current position) + pub fn remaining(&self) -> usize { + self.values.len() - self.offset + } + /// Advance the cursor, returning the previous row index pub fn advance(&mut self) -> usize { let t = self.offset; @@ -100,6 +110,24 @@ impl Cursor { t } + /// Advance the cursor by `n` positions + pub fn advance_by(&mut self, n: usize) { + self.offset += n; + } + + /// Compare the last value in this cursor with the current value of `other`. + /// + /// Returns [`Ordering::Less`] if the last value of this cursor comes + /// before `other`'s current value in sort order. 
+ pub fn last_cmp(&self, other: &Self) -> Ordering { + T::compare( + &self.values, + self.values.len() - 1, + &other.values, + other.offset, + ) + } + pub fn is_eq_to_prev_one(&self, prev_cursor: Option<&Cursor>) -> bool { if self.offset > 0 { self.is_eq_to_prev_row() diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index c29933535adc5..699f392ede5ec 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -302,6 +302,72 @@ impl SortPreservingMergeStream { } let stream_idx = self.loser_tree[0]; + + // Batch pass-through: when the winner's entire remaining + // batch is strictly less than every other stream's current + // value we can skip per-row loser-tree comparisons. + // Only check at the start of a new batch to amortise the + // O(log K) runner-up lookup. + if self.cursors[stream_idx] + .as_ref() + .is_some_and(|c| c.is_at_start()) + && self.can_batch_pass_through(stream_idx) + { + let remaining = self.cursors[stream_idx].as_ref().unwrap().remaining(); + let space_in_batch = + self.batch_size.saturating_sub(self.in_progress.len()); + let fetch_remaining = self + .fetch + .map(|f| f.saturating_sub(self.produced + self.in_progress.len())) + .unwrap_or(usize::MAX); + let rows_to_add = remaining.min(space_in_batch).min(fetch_remaining); + + if rows_to_add > 0 { + // Zero-copy fast path: emit a batch slice directly when + // the in-progress buffer is empty and we can take the + // entire remaining batch. 
+ if self.in_progress.is_empty() && rows_to_add == remaining { + let batch = + self.in_progress.take_batch_slice(stream_idx, rows_to_add); + self.produced += rows_to_add; + + let cursor = self.cursors[stream_idx].as_mut().unwrap(); + cursor.advance_by(rows_to_add); + if cursor.is_finished() { + self.prev_cursors[stream_idx] = + self.cursors[stream_idx].take(); + } + self.loser_tree_adjusted = false; + + if self.fetch_reached() { + self.done = true; + } + return Poll::Ready(Some(Ok(batch))); + } + + // Bulk-push path: append all qualifying rows at once, + // avoiding per-row loser-tree work. + self.in_progress.push_rows(stream_idx, rows_to_add); + + let cursor = self.cursors[stream_idx].as_mut().unwrap(); + cursor.advance_by(rows_to_add); + if cursor.is_finished() { + self.prev_cursors[stream_idx] = self.cursors[stream_idx].take(); + } + self.loser_tree_adjusted = false; + + if self.fetch_reached() { + self.done = true; + self.drain_in_progress_on_done = true; + } else if self.in_progress.len() < self.batch_size { + continue; + } + + return Poll::Ready(self.emit_in_progress_batch().transpose()); + } + } + + // Normal row-by-row path if self.advance_cursors(stream_idx) { self.loser_tree_adjusted = false; self.in_progress.push_row(stream_idx); @@ -341,6 +407,48 @@ impl SortPreservingMergeStream { } } + /// Walk the loser tree to find the runner-up (second-smallest current + /// value). This is the minimum of the losers along the winner's path + /// from leaf to root. Cost: O(log K). 
+    fn find_runner_up(&self) -> Option<usize> {
+        let winner = self.loser_tree[0];
+        let num_streams = self.cursors.len();
+        let mut runner_up: Option<usize> = None;
+
+        let mut node = self.lt_leaf_node_index(winner);
+        while node != 0 {
+            let loser = self.loser_tree[node];
+            if loser < num_streams && self.cursors[loser].is_some() {
+                runner_up = Some(match runner_up {
+                    None => loser,
+                    Some(current) if self.is_gt(current, loser) => loser,
+                    Some(current) => current,
+                });
+            }
+            node = self.lt_parent_node_index(node);
+        }
+        runner_up
+    }
+
+    /// Returns `true` when the winner's entire remaining batch is strictly
+    /// less than every other stream's current value, meaning those rows can
+    /// be emitted without per-row loser-tree comparisons.
+    fn can_batch_pass_through(&self, winner: usize) -> bool {
+        let winner_cursor = match &self.cursors[winner] {
+            Some(c) if c.remaining() > 1 => c,
+            _ => return false,
+        };
+
+        match self.find_runner_up() {
+            // All other streams exhausted — pass through is safe
+            None => true,
+            Some(runner_up) => {
+                let runner_up_cursor = self.cursors[runner_up].as_ref().unwrap();
+                winner_cursor.last_cmp(runner_up_cursor).is_lt()
+            }
+        }
+    }
+
     fn fetch_reached(&mut self) -> bool {
         self.fetch
             .map(|fetch| self.produced + self.in_progress.len() >= fetch)
diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index 1e60c391f50d1..8e3666cf24a5c 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -1595,4 +1595,172 @@ mod tests {
 
         Ok(())
     }
+
+    async fn _test_merge_sort_by_b(
+        partitions: &[Vec<RecordBatch>],
+        exp: &[&str],
+        context: Arc<TaskContext>,
+    ) {
+        let schema = partitions[0][0].schema();
+        let sort: LexOrdering = [PhysicalSortExpr {
+            expr: col("b", &schema).unwrap(),
+            options: Default::default(),
+        }]
+        .into();
+        let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
+        let merge =
Arc::new(SortPreservingMergeExec::new(sort, exec)); + let collected = collect(merge, context).await.unwrap(); + assert_batches_eq!(exp, collected.as_slice()); + } + + /// Test batch pass-through with multiple non-overlapping batches per + /// partition, ensuring cursor advancement and batch cleanup work. + #[tokio::test] + async fn test_batch_pass_through_multi_batch() { + let task_ctx = Arc::new(TaskContext::default()); + + // Partition 0: two batches [a, b] then [c, d] + let b0a = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef), + ("b", Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef), + ]) + .unwrap(); + let b0b = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef), + ("b", Arc::new(StringArray::from(vec!["c", "d"])) as ArrayRef), + ]) + .unwrap(); + + // Partition 1: two batches [e, f] then [g, h] + let b1a = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(Int32Array::from(vec![5, 6])) as ArrayRef), + ("b", Arc::new(StringArray::from(vec!["e", "f"])) as ArrayRef), + ]) + .unwrap(); + let b1b = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(Int32Array::from(vec![7, 8])) as ArrayRef), + ("b", Arc::new(StringArray::from(vec!["g", "h"])) as ArrayRef), + ]) + .unwrap(); + + _test_merge_sort_by_b( + &[vec![b0a, b0b], vec![b1a, b1b]], + &[ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | a |", + "| 2 | b |", + "| 3 | c |", + "| 4 | d |", + "| 5 | e |", + "| 6 | f |", + "| 7 | g |", + "| 8 | h |", + "+---+---+", + ], + task_ctx, + ) + .await; + } + + /// Test batch pass-through with a fetch limit that cuts through a + /// pass-through batch. 
+ #[tokio::test] + async fn test_batch_pass_through_with_fetch() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ])); + + // Partition 0: [a, b, c] + let b0 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + )?; + + // Partition 1: [x, y, z] — completely non-overlapping + let b1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![4, 5, 6])), + Arc::new(StringArray::from(vec!["x", "y", "z"])), + ], + )?; + + let task_ctx = Arc::new(TaskContext::default()); + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("b", &schema)?, + options: Default::default(), + }] + .into(); + let exec = TestMemoryExec::try_new_exec(&[vec![b0], vec![b1]], schema, None)?; + let merge = + Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(4))); + + let collected = collect(merge, task_ctx).await?; + assert_batches_eq!( + [ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | a |", + "| 2 | b |", + "| 3 | c |", + "| 4 | x |", + "+---+---+", + ], + collected.as_slice() + ); + Ok(()) + } + + /// Test that the merge is still correct when batches partially overlap + /// (only some partitions qualify for pass-through). 
+ #[tokio::test] + async fn test_batch_pass_through_partial_overlap() { + let task_ctx = Arc::new(TaskContext::default()); + + // Partition 0: [a, b] — non-overlapping with partition 2 + let b0 = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef), + ("b", Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef), + ]) + .unwrap(); + + // Partition 1: [b, d] — overlaps with partition 0 at "b" + let b1 = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef), + ("b", Arc::new(StringArray::from(vec!["b", "d"])) as ArrayRef), + ]) + .unwrap(); + + // Partition 2: [f, g] — non-overlapping with everything + let b2 = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(Int32Array::from(vec![5, 6])) as ArrayRef), + ("b", Arc::new(StringArray::from(vec!["f", "g"])) as ArrayRef), + ]) + .unwrap(); + + _test_merge_sort_by_b( + &[vec![b0], vec![b1], vec![b2]], + &[ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | a |", + "| 2 | b |", + "| 3 | b |", + "| 4 | d |", + "| 5 | f |", + "| 6 | g |", + "+---+---+", + ], + task_ctx, + ) + .await; + } }