Commit a6a4df9

Fix memory reservation starvation in sort-merge (#20642)
## Which issue does this PR close?

- Closes #.

## Rationale for this change

This PR fixes memory reservation starvation in sort-merge when multiple sort partitions share a `GreedyMemoryPool`. When multiple `ExternalSorter` instances run concurrently against a single memory pool, the merge phase starves:

1. Each partition pre-reserves `sort_spill_reservation_bytes` via `merge_reservation`.
2. On entering the merge phase, the bytes held in `ExternalSorter.merge_reservation` were freed back to the pool, and `new_empty()` created a fresh reservation starting at 0 bytes.
3. The freed bytes were immediately consumed by other partitions racing for memory.
4. The merge could no longer allocate memory from the pool, resulting in OOM / starvation.

## What changes are included in this PR?

## Are these changes tested?

~~I can't find a deterministic way to reproduce the bug, but it occurs in our production.~~ An end-to-end test was added to verify the fix.

## Are there any user-facing changes?
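The race in steps 1-4 can be sketched with toy types. This is a minimal model, not the DataFusion API: `Pool` and `Reservation` are hypothetical stand-ins for `MemoryPool`/`MemoryReservation`, the byte counts are made up, and unlike the real type this toy does not release memory on drop.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Toy greedy pool: a shared byte budget, first come first served.
#[derive(Clone)]
struct Pool(Rc<RefCell<usize>>); // remaining free bytes

impl Pool {
    fn new(capacity: usize) -> Self {
        Pool(Rc::new(RefCell::new(capacity)))
    }
    fn try_grow(&self, bytes: usize) -> Result<(), String> {
        let mut free = self.0.borrow_mut();
        if bytes <= *free {
            *free -= bytes;
            Ok(())
        } else {
            Err(format!("cannot reserve {bytes} bytes"))
        }
    }
    fn shrink(&self, bytes: usize) {
        *self.0.borrow_mut() += bytes;
    }
}

struct Reservation {
    pool: Pool,
    size: usize,
}

impl Reservation {
    /// Return all reserved bytes to the pool.
    fn free(&mut self) {
        self.pool.shrink(self.size);
        self.size = 0;
    }
    /// Like `MemoryReservation::take()`: move the reserved bytes into a
    /// new reservation without ever returning them to the pool.
    fn take(&mut self) -> Reservation {
        let moved = Reservation { pool: self.pool.clone(), size: self.size };
        self.size = 0;
        moved
    }
    /// Like `new_empty()`: a fresh reservation on the same pool, 0 bytes.
    fn new_empty(&self) -> Reservation {
        Reservation { pool: self.pool.clone(), size: 0 }
    }
    fn try_grow(&mut self, bytes: usize) -> Result<(), String> {
        self.pool.try_grow(bytes)?;
        self.size += bytes;
        Ok(())
    }
}

fn main() {
    let pool = Pool::new(100);

    // --- Buggy path: free() + new_empty() ---
    let mut merge_reservation = Reservation { pool: pool.clone(), size: 0 };
    merge_reservation.try_grow(60).unwrap(); // pre-reserve the merge budget

    let mut merge = merge_reservation.new_empty(); // merge starts at 0 bytes
    merge_reservation.free(); // pre-reserved bytes go back to the pool...

    let mut partition_b = Reservation { pool: pool.clone(), size: 0 };
    partition_b.try_grow(100).unwrap(); // ...and a racing partition takes everything
    assert!(merge.try_grow(60).is_err()); // the merge is starved
    partition_b.free();

    // --- Fixed path: take() ---
    merge_reservation.try_grow(60).unwrap();
    let merge = merge_reservation.take(); // budget moves without touching the pool
    let mut partition_b = Reservation { pool, size: 0 };
    assert!(partition_b.try_grow(100).is_err()); // only 40 bytes remain free
    assert_eq!(merge.size, 60); // the merge keeps its pre-reserved budget
    println!("take() preserves the merge budget under contention");
}
```

The key difference: `free()` makes the bytes visible to every competing consumer before the merge can re-acquire them, while `take()` transfers ownership without the pool ever seeing them as free.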
1 parent b7a3f53 commit a6a4df9

3 files changed

Lines changed: 239 additions & 22 deletions

datafusion/physical-plan/src/sorts/builder.rs

Lines changed: 50 additions & 4 deletions
```diff
@@ -40,9 +40,24 @@ pub struct BatchBuilder {
     /// Maintain a list of [`RecordBatch`] and their corresponding stream
     batches: Vec<(usize, RecordBatch)>,

-    /// Accounts for memory used by buffered batches
+    /// Accounts for memory used by buffered batches.
+    ///
+    /// May include pre-reserved bytes (from `sort_spill_reservation_bytes`)
+    /// that were transferred via [`MemoryReservation::take()`] to prevent
+    /// starvation when concurrent sort partitions compete for pool memory.
     reservation: MemoryReservation,

+    /// Tracks the actual memory used by buffered batches (not including
+    /// pre-reserved bytes). This allows [`Self::push_batch`] to skip pool
+    /// allocation requests when the pre-reserved bytes cover the batch.
+    batches_mem_used: usize,
+
+    /// The initial reservation size at construction time. When the reservation
+    /// is pre-loaded with `sort_spill_reservation_bytes` (via `take()`), this
+    /// records that amount so we never shrink below it, maintaining the
+    /// anti-starvation guarantee throughout the merge.
+    initial_reservation: usize,
+
     /// The current [`BatchCursor`] for each stream
     cursors: Vec<BatchCursor>,

@@ -59,19 +74,26 @@ impl BatchBuilder {
         batch_size: usize,
         reservation: MemoryReservation,
     ) -> Self {
+        let initial_reservation = reservation.size();
        Self {
            schema,
            batches: Vec::with_capacity(stream_count * 2),
            cursors: vec![BatchCursor::default(); stream_count],
            indices: Vec::with_capacity(batch_size),
            reservation,
+            batches_mem_used: 0,
+            initial_reservation,
        }
    }

    /// Append a new batch in `stream_idx`
    pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> {
-        self.reservation
-            .try_grow(get_record_batch_memory_size(&batch))?;
+        let size = get_record_batch_memory_size(&batch);
+        self.batches_mem_used += size;
+        // Only request additional memory from the pool when actual batch
+        // usage exceeds the current reservation (which may include
+        // pre-reserved bytes from sort_spill_reservation_bytes).
+        try_grow_reservation_to_at_least(&mut self.reservation, self.batches_mem_used)?;
        let batch_idx = self.batches.len();
        self.batches.push((stream_idx, batch));
        self.cursors[stream_idx] = BatchCursor {

@@ -143,14 +165,38 @@ impl BatchBuilder {
                stream_cursor.batch_idx = retained;
                retained += 1;
            } else {
-                self.reservation.shrink(get_record_batch_memory_size(batch));
+                self.batches_mem_used -= get_record_batch_memory_size(batch);
            }
            retain
        });

+        // Release excess memory back to the pool, but never shrink below
+        // initial_reservation to maintain the anti-starvation guarantee
+        // for the merge phase.
+        let target = self.batches_mem_used.max(self.initial_reservation);
+        if self.reservation.size() > target {
+            self.reservation.shrink(self.reservation.size() - target);
+        }
+
        Ok(Some(RecordBatch::try_new(
            Arc::clone(&self.schema),
            columns,
        )?))
    }
}
+
+/// Try to grow `reservation` so it covers at least `needed` bytes.
+///
+/// When a reservation has been pre-loaded with bytes (e.g. via
+/// [`MemoryReservation::take()`]), this avoids redundant pool
+/// allocations: if the reservation already covers `needed`, this is
+/// a no-op; otherwise only the deficit is requested from the pool.
+pub(crate) fn try_grow_reservation_to_at_least(
+    reservation: &mut MemoryReservation,
+    needed: usize,
+) -> Result<()> {
+    if needed > reservation.size() {
+        reservation.try_grow(needed - reservation.size())?;
+    }
+    Ok(())
+}
```
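The helper's contract can be exercised in isolation. The sketch below uses a toy reservation (`ToyReservation` and its request log are hypothetical, not the DataFusion API) to show the two properties the doc comment promises: no pool traffic while the pre-reserved budget covers usage, and only the deficit requested once usage exceeds it.

```rust
struct ToyReservation {
    size: usize,
    pool_requests: Vec<usize>, // bytes requested from the pool, for inspection
}

impl ToyReservation {
    fn try_grow(&mut self, bytes: usize) -> Result<(), String> {
        // Toy pool: always succeeds, but records every request.
        self.pool_requests.push(bytes);
        self.size += bytes;
        Ok(())
    }
}

/// Grow `reservation` so it covers at least `needed` bytes, requesting
/// only the deficit (mirrors the helper added in builder.rs).
fn try_grow_reservation_to_at_least(
    reservation: &mut ToyReservation,
    needed: usize,
) -> Result<(), String> {
    if needed > reservation.size {
        let deficit = needed - reservation.size;
        reservation.try_grow(deficit)?;
    }
    Ok(())
}

fn main() {
    // Pre-loaded with 10 KiB, as if transferred via take().
    let mut r = ToyReservation { size: 10 * 1024, pool_requests: vec![] };

    // The first batches fit inside the pre-reserved budget: no pool traffic.
    try_grow_reservation_to_at_least(&mut r, 4 * 1024).unwrap();
    try_grow_reservation_to_at_least(&mut r, 9 * 1024).unwrap();
    assert!(r.pool_requests.is_empty());

    // A later batch exceeds the budget: only the 2 KiB deficit is requested.
    try_grow_reservation_to_at_least(&mut r, 12 * 1024).unwrap();
    assert_eq!(r.pool_requests, vec![2 * 1024]);
    assert_eq!(r.size, 12 * 1024);
    println!("only the deficit was requested from the pool");
}
```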

datafusion/physical-plan/src/sorts/multi_level_merge.rs

Lines changed: 36 additions & 12 deletions
```diff
@@ -30,6 +30,7 @@ use arrow::datatypes::SchemaRef;
use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;

+use crate::sorts::builder::try_grow_reservation_to_at_least;
use crate::sorts::sort::get_reserved_bytes_for_record_batch_size;
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
use crate::stream::RecordBatchStreamAdapter;

@@ -253,7 +254,12 @@ impl MultiLevelMergeBuilder {

            // Need to merge multiple streams
            (_, _) => {
-                let mut memory_reservation = self.reservation.new_empty();
+                // Transfer any pre-reserved bytes (from sort_spill_reservation_bytes)
+                // to the merge memory reservation. This prevents starvation when
+                // concurrent sort partitions compete for pool memory: the pre-reserved
+                // bytes cover spill file buffer reservations without additional pool
+                // allocation.
+                let mut memory_reservation = self.reservation.take();

                // Don't account for existing streams memory
                // as we are not holding the memory for them

@@ -269,6 +275,15 @@ impl MultiLevelMergeBuilder {

                let is_only_merging_memory_streams = sorted_spill_files.is_empty();

+                // If no spill files were selected (e.g. all too large for
+                // available memory but enough in-memory streams exist),
+                // return the pre-reserved bytes to self.reservation so
+                // create_new_merge_sort can transfer them to the merge
+                // stream's BatchBuilder.
+                if is_only_merging_memory_streams {
+                    mem::swap(&mut self.reservation, &mut memory_reservation);
+                }
+
                for spill in sorted_spill_files {
                    let stream = self
                        .spill_manager

@@ -337,8 +352,10 @@ impl MultiLevelMergeBuilder {
            builder = builder.with_bypass_mempool();
        } else {
            // If we are only merging in-memory streams, we need to use the memory reservation
-            // because we don't know the maximum size of the batches in the streams
-            builder = builder.with_reservation(self.reservation.new_empty());
+            // because we don't know the maximum size of the batches in the streams.
+            // Use take() to transfer any pre-reserved bytes so the merge can use them
+            // as its initial budget without additional pool allocation.
+            builder = builder.with_reservation(self.reservation.take());
        }

        builder.build()

@@ -356,17 +373,24 @@ impl MultiLevelMergeBuilder {
    ) -> Result<(Vec<SortedSpillFile>, usize)> {
        assert_ne!(buffer_len, 0, "Buffer length must be greater than 0");
        let mut number_of_spills_to_read_for_current_phase = 0;
+        // Track total memory needed for spill file buffers. When the
+        // reservation has pre-reserved bytes (from sort_spill_reservation_bytes),
+        // those bytes cover the first N spill files without additional pool
+        // allocation, preventing starvation under memory pressure.
+        let mut total_needed: usize = 0;

        for spill in &self.sorted_spill_files {
-            // For memory pools that are not shared this is good, for other this is not
-            // and there should be some upper limit to memory reservation so we won't starve the system
-            match reservation.try_grow(
-                get_reserved_bytes_for_record_batch_size(
-                    spill.max_record_batch_memory,
-                    // Size will be the same as the sliced size, bc it is a spilled batch.
-                    spill.max_record_batch_memory,
-                ) * buffer_len,
-            ) {
+            let per_spill = get_reserved_bytes_for_record_batch_size(
+                spill.max_record_batch_memory,
+                // Size will be the same as the sliced size, bc it is a spilled batch.
+                spill.max_record_batch_memory,
+            ) * buffer_len;
+            total_needed += per_spill;
+
+            // For memory pools that are not shared this is good, for other
+            // this is not and there should be some upper limit to memory
+            // reservation so we won't starve the system.
+            match try_grow_reservation_to_at_least(reservation, total_needed) {
                Ok(_) => {
                    number_of_spills_to_read_for_current_phase += 1;
                }
```
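The cumulative `total_needed` accounting in the spill-file selection loop above can be sketched standalone. The types and byte counts below are hypothetical (a toy `Reservation`, not the real `MemoryReservation`): a pre-loaded reservation admits the first spill files without any pool traffic, and selection stops at the first file whose deficit the pool cannot cover.

```rust
struct Reservation {
    size: usize,
    pool_free: usize, // what the (toy) pool can still hand out
}

impl Reservation {
    fn try_grow(&mut self, bytes: usize) -> Result<(), ()> {
        if bytes <= self.pool_free {
            self.pool_free -= bytes;
            self.size += bytes;
            Ok(())
        } else {
            Err(())
        }
    }
}

/// Grow to at least `needed`, requesting only the deficit from the pool.
fn try_grow_to_at_least(r: &mut Reservation, needed: usize) -> Result<(), ()> {
    if needed > r.size {
        let deficit = needed - r.size;
        r.try_grow(deficit)?;
    }
    Ok(())
}

fn main() {
    // Pre-reserved 64 KiB (as after take()); the pool has only 16 KiB left.
    let mut reservation = Reservation { size: 64 * 1024, pool_free: 16 * 1024 };
    let per_spill_needs = [32 * 1024usize, 32 * 1024, 32 * 1024];

    let mut total_needed = 0;
    let mut admitted = 0;
    for need in per_spill_needs {
        total_needed += need;
        if try_grow_to_at_least(&mut reservation, total_needed).is_ok() {
            admitted += 1;
        } else {
            break; // this spill file is deferred to a later merge phase
        }
    }

    // Files 1 and 2 fit inside the 64 KiB pre-reserved budget; file 3
    // needs 32 KiB more but only 16 KiB remain free in the pool.
    assert_eq!(admitted, 2);
    assert_eq!(reservation.size, 64 * 1024);
    println!("pre-reserved budget admitted {admitted} spill files");
}
```

Compare with the old per-file `try_grow`: it would have requested 32 KiB from the pool for the very first file, ignoring the budget the reservation already holds.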

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 153 additions & 6 deletions
```diff
@@ -342,11 +342,6 @@ impl ExternalSorter {
    /// 2. A combined streaming merge incorporating both in-memory
    ///    batches and data from spill files on disk.
    async fn sort(&mut self) -> Result<SendableRecordBatchStream> {
-        // Release the memory reserved for merge back to the pool so
-        // there is some left when `in_mem_sort_stream` requests an
-        // allocation.
-        self.merge_reservation.free();
-
        if self.spilled_before() {
            // Sort `in_mem_batches` and spill it first. If there are many
            // `in_mem_batches` and the memory limit is almost reached, merging

@@ -355,6 +350,13 @@ impl ExternalSorter {
                self.sort_and_spill_in_mem_batches().await?;
            }

+            // Transfer the pre-reserved merge memory to the streaming merge
+            // using `take()` instead of `new_empty()`. This ensures the merge
+            // stream starts with `sort_spill_reservation_bytes` already
+            // allocated, preventing starvation when concurrent sort partitions
+            // compete for pool memory. `take()` moves the bytes atomically
+            // without releasing them back to the pool, so other partitions
+            // cannot race to consume the freed memory.
            StreamingMergeBuilder::new()
                .with_sorted_spill_files(std::mem::take(&mut self.finished_spill_files))
                .with_spill_manager(self.spill_manager.clone())

@@ -363,9 +365,14 @@ impl ExternalSorter {
                .with_metrics(self.metrics.baseline.clone())
                .with_batch_size(self.batch_size)
                .with_fetch(None)
-                .with_reservation(self.merge_reservation.new_empty())
+                .with_reservation(self.merge_reservation.take())
                .build()
        } else {
+            // Release the memory reserved for merge back to the pool so
+            // there is some left when `in_mem_sort_stream` requests an
+            // allocation. Only needed for the non-spill path; the spill
+            // path transfers the reservation to the merge stream instead.
+            self.merge_reservation.free();
            self.in_mem_sort_stream(self.metrics.baseline.clone())
        }
    }

@@ -375,6 +382,12 @@ impl ExternalSorter {
        self.reservation.size()
    }

+    /// How much memory is reserved for the merge phase?
+    #[cfg(test)]
+    fn merge_reservation_size(&self) -> usize {
+        self.merge_reservation.size()
+    }
+
    /// How many bytes have been spilled to disk?
    fn spilled_bytes(&self) -> usize {
        self.metrics.spill_metrics.spilled_bytes.value()

@@ -2716,4 +2729,138 @@ mod tests {

        Ok(())
    }
+
+    /// Verifies that `ExternalSorter::sort()` transfers the pre-reserved
+    /// merge bytes to the merge stream via `take()`, rather than leaving
+    /// them in the sorter (via `new_empty()`).
+    ///
+    /// 1. Create a sorter with a tight memory pool and insert enough data
+    ///    to force spilling
+    /// 2. Verify `merge_reservation` holds the pre-reserved bytes before sort
+    /// 3. Call `sort()` to get the merge stream
+    /// 4. Verify `merge_reservation` is now 0 (bytes transferred to merge stream)
+    /// 5. Simulate contention: a competing consumer grabs all available pool memory
+    /// 6. Verify the merge stream still works (it uses its pre-reserved bytes
+    ///    as initial budget, not requesting from pool starting at 0)
+    ///
+    /// With `new_empty()` (before fix), step 4 fails: `merge_reservation`
+    /// still holds the bytes, the merge stream starts with 0 budget, and
+    /// those bytes become unaccounted-for reserved memory that nobody uses.
+    #[tokio::test]
+    async fn test_sort_merge_reservation_transferred_not_freed() -> Result<()> {
+        use datafusion_execution::memory_pool::{
+            GreedyMemoryPool, MemoryConsumer, MemoryPool,
+        };
+        use futures::TryStreamExt;
+
+        let sort_spill_reservation_bytes: usize = 10 * 1024; // 10 KB
+
+        // Pool: merge reservation (10KB) + enough room for sort to work.
+        // The room must accommodate batch data accumulation before spilling.
+        let sort_working_memory: usize = 40 * 1024; // 40 KB for sort operations
+        let pool_size = sort_spill_reservation_bytes + sort_working_memory;
+        let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(pool_size));
+
+        let runtime = RuntimeEnvBuilder::new()
+            .with_memory_pool(Arc::clone(&pool))
+            .build_arc()?;
+
+        let metrics_set = ExecutionPlanMetricsSet::new();
+        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
+
+        let mut sorter = ExternalSorter::new(
+            0,
+            Arc::clone(&schema),
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("x", 0)))].into(),
+            128, // batch_size
+            sort_spill_reservation_bytes,
+            usize::MAX, // sort_in_place_threshold_bytes (high to avoid concat path)
+            SpillCompression::Uncompressed,
+            &metrics_set,
+            Arc::clone(&runtime),
+        )?;
+
+        // Insert enough data to force spilling.
+        let num_batches = 200;
+        for i in 0..num_batches {
+            let values: Vec<i32> = ((i * 100)..((i + 1) * 100)).rev().collect();
+            let batch = RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![Arc::new(Int32Array::from(values))],
+            )?;
+            sorter.insert_batch(batch).await?;
+        }
+
+        assert!(
+            sorter.spilled_before(),
+            "Test requires spilling to exercise the merge path"
+        );
+
+        // Before sort(), merge_reservation holds sort_spill_reservation_bytes.
+        assert!(
+            sorter.merge_reservation_size() >= sort_spill_reservation_bytes,
+            "merge_reservation should hold the pre-reserved bytes before sort()"
+        );
+
+        // Call sort() to get the merge stream. With the fix (take()),
+        // the pre-reserved merge bytes are transferred to the merge
+        // stream. Without the fix (free() + new_empty()), the bytes
+        // are released back to the pool and the merge stream starts
+        // with 0 bytes.
+        let merge_stream = sorter.sort().await?;
+
+        // THE KEY ASSERTION: after sort(), merge_reservation must be 0.
+        // This proves take() transferred the bytes to the merge stream,
+        // rather than them being freed back to the pool where other
+        // partitions could steal them.
+        assert_eq!(
+            sorter.merge_reservation_size(),
+            0,
+            "After sort(), merge_reservation should be 0 (bytes transferred \
+             to merge stream via take()). If non-zero, the bytes are still \
+             held by the sorter and will be freed on drop, allowing other \
+             partitions to steal them."
+        );
+
+        // Drop the sorter to free its reservations back to the pool.
+        drop(sorter);
+
+        // Simulate contention: another partition grabs ALL available
+        // pool memory. If the merge stream didn't receive the
+        // pre-reserved bytes via take(), it will fail when it tries
+        // to allocate memory for reading spill files.
+        let contender = MemoryConsumer::new("CompetingPartition").register(&pool);
+        let available = pool_size.saturating_sub(pool.reserved());
+        if available > 0 {
+            contender.try_grow(available).unwrap();
+        }
+
+        // The merge stream must still produce correct results despite
+        // the pool being fully consumed by the contender. This only
+        // works if sort() transferred the pre-reserved bytes to the
+        // merge stream (via take()) rather than freeing them.
+        let batches: Vec<RecordBatch> = merge_stream.try_collect().await?;
+        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
+        assert_eq!(
+            total_rows,
+            (num_batches * 100) as usize,
+            "Merge stream should produce all rows even under memory contention"
+        );
+
+        // Verify data is sorted
+        let merged = concat_batches(&schema, &batches)?;
+        let col = merged.column(0).as_primitive::<Int32Type>();
+        for i in 1..col.len() {
+            assert!(
+                col.value(i - 1) <= col.value(i),
+                "Output should be sorted, but found {} > {} at index {}",
+                col.value(i - 1),
+                col.value(i),
+                i
+            );
+        }
+
+        drop(contender);
+        Ok(())
+    }
}
```
