Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions datafusion/physical-plan/src/sorts/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,34 @@ impl BatchBuilder {
self.indices.push((cursor.batch_idx, row_idx));
}

/// Append `count` consecutive rows from `stream_idx`
///
/// The rows come from the stream's current batch, starting at its cursor
/// position; the cursor is advanced past the appended rows.
pub fn push_rows(&mut self, stream_idx: usize, count: usize) {
    let cursor = &mut self.cursors[stream_idx];
    let batch = cursor.batch_idx;
    let first = cursor.row_idx;
    // Advance the cursor first; the appended indices are derived from the
    // saved starting position, so order does not matter semantically.
    cursor.row_idx = first + count;
    self.indices
        .extend((first..first + count).map(|row| (batch, row)));
}

/// Slice the current batch for `stream_idx` starting at its cursor
/// position, returning `num_rows` rows as a zero-copy [`RecordBatch`].
///
/// Advances the builder's cursor but does **not** touch `self.indices`,
/// so the caller must not also call `push_row`/`push_rows` for these
/// rows.
pub fn take_batch_slice(
    &mut self,
    stream_idx: usize,
    num_rows: usize,
) -> RecordBatch {
    let cursor = &mut self.cursors[stream_idx];
    let offset = cursor.row_idx;
    cursor.row_idx = offset + num_rows;
    // `slice` shares the underlying buffers — no row data is copied.
    let (_, batch) = &self.batches[cursor.batch_idx];
    batch.slice(offset, num_rows)
}

/// Returns the number of in-progress rows in this [`BatchBuilder`]
pub fn len(&self) -> usize {
self.indices.len()
Expand Down
28 changes: 28 additions & 0 deletions datafusion/physical-plan/src/sorts/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,41 @@ impl<T: CursorValues> Cursor<T> {
self.offset == self.values.len()
}

/// Returns true if the cursor is at the start (offset 0)
///
/// i.e. no rows of the current batch have been consumed yet. Used by the
/// merge loop to check the pass-through fast path only once per batch.
pub fn is_at_start(&self) -> bool {
    self.offset == 0
}

/// Returns the number of remaining rows (including current position)
///
/// Equals `values.len() - offset`; returns 0 once the cursor is finished.
pub fn remaining(&self) -> usize {
    self.values.len() - self.offset
}

/// Advance the cursor, returning the previous row index
pub fn advance(&mut self) -> usize {
    let current = self.offset;
    self.offset = current + 1;
    current
}

/// Advance the cursor by `n` positions
///
/// No bounds check is performed; callers are expected to cap `n` at
/// [`Self::remaining`] so the offset never passes the end of `values`.
pub fn advance_by(&mut self, n: usize) {
    self.offset += n;
}

/// Compare the last value in this cursor with the current value of `other`.
///
/// Returns [`Ordering::Less`] if the last value of this cursor comes
/// before `other`'s current value in sort order.
///
/// NOTE(review): assumes `values` is non-empty (`len() - 1` would
/// underflow otherwise) — callers gate on `remaining() > 1`.
pub fn last_cmp(&self, other: &Self) -> Ordering {
    let last_idx = self.values.len() - 1;
    T::compare(&self.values, last_idx, &other.values, other.offset)
}

pub fn is_eq_to_prev_one(&self, prev_cursor: Option<&Cursor<T>>) -> bool {
if self.offset > 0 {
self.is_eq_to_prev_row()
Expand Down
108 changes: 108 additions & 0 deletions datafusion/physical-plan/src/sorts/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,72 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
}

let stream_idx = self.loser_tree[0];

// Batch pass-through: when the winner's entire remaining
// batch is strictly less than every other stream's current
// value we can skip per-row loser-tree comparisons.
// Only check at the start of a new batch to amortise the
// O(log K) runner-up lookup.
if self.cursors[stream_idx]
.as_ref()
.is_some_and(|c| c.is_at_start())
&& self.can_batch_pass_through(stream_idx)
{
let remaining = self.cursors[stream_idx].as_ref().unwrap().remaining();
let space_in_batch =
self.batch_size.saturating_sub(self.in_progress.len());
let fetch_remaining = self
.fetch
.map(|f| f.saturating_sub(self.produced + self.in_progress.len()))
.unwrap_or(usize::MAX);
let rows_to_add = remaining.min(space_in_batch).min(fetch_remaining);

if rows_to_add > 0 {
// Zero-copy fast path: emit a batch slice directly when
// the in-progress buffer is empty and we can take the
// entire remaining batch.
if self.in_progress.is_empty() && rows_to_add == remaining {
let batch =
self.in_progress.take_batch_slice(stream_idx, rows_to_add);
self.produced += rows_to_add;

let cursor = self.cursors[stream_idx].as_mut().unwrap();
cursor.advance_by(rows_to_add);
if cursor.is_finished() {
self.prev_cursors[stream_idx] =
self.cursors[stream_idx].take();
}
self.loser_tree_adjusted = false;

if self.fetch_reached() {
self.done = true;
}
return Poll::Ready(Some(Ok(batch)));
}

// Bulk-push path: append all qualifying rows at once,
// avoiding per-row loser-tree work.
self.in_progress.push_rows(stream_idx, rows_to_add);

let cursor = self.cursors[stream_idx].as_mut().unwrap();
cursor.advance_by(rows_to_add);
if cursor.is_finished() {
self.prev_cursors[stream_idx] = self.cursors[stream_idx].take();
}
self.loser_tree_adjusted = false;

if self.fetch_reached() {
self.done = true;
self.drain_in_progress_on_done = true;
} else if self.in_progress.len() < self.batch_size {
continue;
}

return Poll::Ready(self.emit_in_progress_batch().transpose());
}
}

// Normal row-by-row path
if self.advance_cursors(stream_idx) {
self.loser_tree_adjusted = false;
self.in_progress.push_row(stream_idx);
Expand Down Expand Up @@ -341,6 +407,48 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
}
}

/// Walk the loser tree to find the runner-up (second-smallest current
/// value). This is the minimum of the losers along the winner's path
/// from leaf to root. Cost: O(log K).
fn find_runner_up(&self) -> Option<usize> {
    let winner = self.loser_tree[0];
    let num_streams = self.cursors.len();
    let mut best: Option<usize> = None;

    let mut node = self.lt_leaf_node_index(winner);
    while node != 0 {
        let candidate = self.loser_tree[node];
        // Ignore sentinel entries and streams whose cursor is exhausted
        if candidate < num_streams && self.cursors[candidate].is_some() {
            let replace = match best {
                None => true,
                Some(current) => self.is_gt(current, candidate),
            };
            if replace {
                best = Some(candidate);
            }
        }
        node = self.lt_parent_node_index(node);
    }
    best
}

/// Returns `true` when the winner's entire remaining batch is strictly
/// less than every other stream's current value, meaning those rows can
/// be emitted without per-row loser-tree comparisons.
fn can_batch_pass_through(&self, winner: usize) -> bool {
    // Only worth checking when more than one row remains; a single row
    // is just as cheap through the normal per-row path.
    let Some(winner_cursor) = self.cursors[winner].as_ref() else {
        return false;
    };
    if winner_cursor.remaining() <= 1 {
        return false;
    }

    if let Some(runner_up) = self.find_runner_up() {
        let runner_up_cursor = self.cursors[runner_up]
            .as_ref()
            .expect("runner-up stream must have an active cursor");
        // Strict less-than: ties must still go through per-row merging
        // to preserve stream-order stability.
        winner_cursor.last_cmp(runner_up_cursor).is_lt()
    } else {
        // All other streams exhausted — pass through is safe
        true
    }
}

fn fetch_reached(&mut self) -> bool {
self.fetch
.map(|fetch| self.produced + self.in_progress.len() >= fetch)
Expand Down
168 changes: 168 additions & 0 deletions datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1595,4 +1595,172 @@ mod tests {

Ok(())
}

/// Run a [`SortPreservingMergeExec`] over `partitions` ordered by column
/// "b" and assert the collected output matches the expected table `exp`.
async fn _test_merge_sort_by_b(
    partitions: &[Vec<RecordBatch>],
    exp: &[&str],
    context: Arc<TaskContext>,
) {
    let schema = partitions[0][0].schema();
    let ordering: LexOrdering = [PhysicalSortExpr {
        expr: col("b", &schema).unwrap(),
        options: Default::default(),
    }]
    .into();
    let source = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
    let plan = Arc::new(SortPreservingMergeExec::new(ordering, source));
    let result = collect(plan, context).await.unwrap();
    assert_batches_eq!(exp, result.as_slice());
}

/// Test batch pass-through with multiple non-overlapping batches per
/// partition, ensuring cursor advancement and batch cleanup work.
#[tokio::test]
async fn test_batch_pass_through_multi_batch() {
    let task_ctx = Arc::new(TaskContext::default());

    // Partition 0: two batches [a, b] then [c, d]
    let b0a = RecordBatch::try_from_iter(vec![
        ("a", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef),
        ("b", Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef),
    ])
    .unwrap();
    let b0b = RecordBatch::try_from_iter(vec![
        ("a", Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef),
        ("b", Arc::new(StringArray::from(vec!["c", "d"])) as ArrayRef),
    ])
    .unwrap();

    // Partition 1: two batches [e, f] then [g, h]
    let b1a = RecordBatch::try_from_iter(vec![
        ("a", Arc::new(Int32Array::from(vec![5, 6])) as ArrayRef),
        ("b", Arc::new(StringArray::from(vec!["e", "f"])) as ArrayRef),
    ])
    .unwrap();
    let b1b = RecordBatch::try_from_iter(vec![
        ("a", Arc::new(Int32Array::from(vec![7, 8])) as ArrayRef),
        ("b", Arc::new(StringArray::from(vec!["g", "h"])) as ArrayRef),
    ])
    .unwrap();

    // Sorted on "b" the partitions are fully disjoint, so each batch
    // should qualify for wholesale pass-through in partition order.
    _test_merge_sort_by_b(
        &[vec![b0a, b0b], vec![b1a, b1b]],
        &[
            "+---+---+",
            "| a | b |",
            "+---+---+",
            "| 1 | a |",
            "| 2 | b |",
            "| 3 | c |",
            "| 4 | d |",
            "| 5 | e |",
            "| 6 | f |",
            "| 7 | g |",
            "| 8 | h |",
            "+---+---+",
        ],
        task_ctx,
    )
    .await;
}

/// Test batch pass-through with a fetch limit that cuts through a
/// pass-through batch.
#[tokio::test]
async fn test_batch_pass_through_with_fetch() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));

    // Partition 0: [a, b, c]
    let b0 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ],
    )?;

    // Partition 1: [x, y, z] — completely non-overlapping
    let b1 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![4, 5, 6])),
            Arc::new(StringArray::from(vec!["x", "y", "z"])),
        ],
    )?;

    let task_ctx = Arc::new(TaskContext::default());
    let sort: LexOrdering = [PhysicalSortExpr {
        expr: col("b", &schema)?,
        options: Default::default(),
    }]
    .into();
    let exec = TestMemoryExec::try_new_exec(&[vec![b0], vec![b1]], schema, None)?;
    // fetch = 4: all of partition 0 plus only the first row of
    // partition 1's otherwise pass-through-eligible batch.
    let merge =
        Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(4)));

    let collected = collect(merge, task_ctx).await?;
    assert_batches_eq!(
        [
            "+---+---+",
            "| a | b |",
            "+---+---+",
            "| 1 | a |",
            "| 2 | b |",
            "| 3 | c |",
            "| 4 | x |",
            "+---+---+",
        ],
        collected.as_slice()
    );
    Ok(())
}

/// Test that the merge is still correct when batches partially overlap
/// (only some partitions qualify for pass-through).
#[tokio::test]
async fn test_batch_pass_through_partial_overlap() {
    let task_ctx = Arc::new(TaskContext::default());

    // Partition 0: [a, b] — non-overlapping with partition 2
    let b0 = RecordBatch::try_from_iter(vec![
        ("a", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef),
        ("b", Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef),
    ])
    .unwrap();

    // Partition 1: [b, d] — overlaps with partition 0 at "b"
    let b1 = RecordBatch::try_from_iter(vec![
        ("a", Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef),
        ("b", Arc::new(StringArray::from(vec!["b", "d"])) as ArrayRef),
    ])
    .unwrap();

    // Partition 2: [f, g] — non-overlapping with everything
    let b2 = RecordBatch::try_from_iter(vec![
        ("a", Arc::new(Int32Array::from(vec![5, 6])) as ArrayRef),
        ("b", Arc::new(StringArray::from(vec!["f", "g"])) as ArrayRef),
    ])
    .unwrap();

    // Expected order: partition 0's "b" (a=2) before partition 1's "b"
    // (a=3), i.e. ties resolve by stream order — pass-through must not
    // break this stability.
    _test_merge_sort_by_b(
        &[vec![b0], vec![b1], vec![b2]],
        &[
            "+---+---+",
            "| a | b |",
            "+---+---+",
            "| 1 | a |",
            "| 2 | b |",
            "| 3 | b |",
            "| 4 | d |",
            "| 5 | f |",
            "| 6 | g |",
            "+---+---+",
        ],
        task_ctx,
    )
    .await;
}
}
Loading