Skip to content

Commit bcd42b0

Browse files
authored
fix: Unaccounted spill sort in row_hash (#20314)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #20313. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> We must not use that much memory without reserving it. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Added a reservation before the sort, made a shrink call for the group values after the emit, and updated the reservation so the reservation will be possible. Moved the sort to use sort_chunked so we can immediately drop the original batch and shrink the reservation to the used sizes, and added a new spill method for iterators, so we can use accurate memory accounting. If said reservation does not succeed, we fall back to an incremental sort method which holds the original batch the whole time and outputs one batch at a time; this requires a much smaller reservation. Made the reservation much more robust (otherwise the fuzz tests were failing now that we actually reserve the memory in the sort). ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Current tests should still function, but memory should be reserved. 
Added a test that specifically verifies that we error on this when we shouldn't do the sort. Modified the tests that used to test the splitting function in the spill to test the new iter spilling function. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> No
1 parent 33b86fe commit bcd42b0

7 files changed

Lines changed: 397 additions & 89 deletions

File tree

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3837,6 +3837,142 @@ mod tests {
38373837
Ok(())
38383838
}
38393839

3840+
/// Tests that when the memory pool is too small to accommodate the sort
3841+
/// reservation during spill, the error is properly propagated as
3842+
/// ResourcesExhausted rather than silently exceeding memory limits.
3843+
#[tokio::test]
3844+
async fn test_sort_reservation_fails_during_spill() -> Result<()> {
3845+
let schema = Arc::new(Schema::new(vec![
3846+
Field::new("g", DataType::Int64, false),
3847+
Field::new("a", DataType::Float64, false),
3848+
Field::new("b", DataType::Float64, false),
3849+
Field::new("c", DataType::Float64, false),
3850+
Field::new("d", DataType::Float64, false),
3851+
Field::new("e", DataType::Float64, false),
3852+
]));
3853+
3854+
let batches = vec![vec![
3855+
RecordBatch::try_new(
3856+
Arc::clone(&schema),
3857+
vec![
3858+
Arc::new(Int64Array::from(vec![1])),
3859+
Arc::new(Float64Array::from(vec![10.0])),
3860+
Arc::new(Float64Array::from(vec![20.0])),
3861+
Arc::new(Float64Array::from(vec![30.0])),
3862+
Arc::new(Float64Array::from(vec![40.0])),
3863+
Arc::new(Float64Array::from(vec![50.0])),
3864+
],
3865+
)?,
3866+
RecordBatch::try_new(
3867+
Arc::clone(&schema),
3868+
vec![
3869+
Arc::new(Int64Array::from(vec![2])),
3870+
Arc::new(Float64Array::from(vec![11.0])),
3871+
Arc::new(Float64Array::from(vec![21.0])),
3872+
Arc::new(Float64Array::from(vec![31.0])),
3873+
Arc::new(Float64Array::from(vec![41.0])),
3874+
Arc::new(Float64Array::from(vec![51.0])),
3875+
],
3876+
)?,
3877+
RecordBatch::try_new(
3878+
Arc::clone(&schema),
3879+
vec![
3880+
Arc::new(Int64Array::from(vec![3])),
3881+
Arc::new(Float64Array::from(vec![12.0])),
3882+
Arc::new(Float64Array::from(vec![22.0])),
3883+
Arc::new(Float64Array::from(vec![32.0])),
3884+
Arc::new(Float64Array::from(vec![42.0])),
3885+
Arc::new(Float64Array::from(vec![52.0])),
3886+
],
3887+
)?,
3888+
]];
3889+
3890+
let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
3891+
3892+
let aggr = Arc::new(AggregateExec::try_new(
3893+
AggregateMode::Single,
3894+
PhysicalGroupBy::new(
3895+
vec![(col("g", schema.as_ref())?, "g".to_string())],
3896+
vec![],
3897+
vec![vec![false]],
3898+
false,
3899+
),
3900+
vec![
3901+
Arc::new(
3902+
AggregateExprBuilder::new(
3903+
avg_udaf(),
3904+
vec![col("a", schema.as_ref())?],
3905+
)
3906+
.schema(Arc::clone(&schema))
3907+
.alias("AVG(a)")
3908+
.build()?,
3909+
),
3910+
Arc::new(
3911+
AggregateExprBuilder::new(
3912+
avg_udaf(),
3913+
vec![col("b", schema.as_ref())?],
3914+
)
3915+
.schema(Arc::clone(&schema))
3916+
.alias("AVG(b)")
3917+
.build()?,
3918+
),
3919+
Arc::new(
3920+
AggregateExprBuilder::new(
3921+
avg_udaf(),
3922+
vec![col("c", schema.as_ref())?],
3923+
)
3924+
.schema(Arc::clone(&schema))
3925+
.alias("AVG(c)")
3926+
.build()?,
3927+
),
3928+
Arc::new(
3929+
AggregateExprBuilder::new(
3930+
avg_udaf(),
3931+
vec![col("d", schema.as_ref())?],
3932+
)
3933+
.schema(Arc::clone(&schema))
3934+
.alias("AVG(d)")
3935+
.build()?,
3936+
),
3937+
Arc::new(
3938+
AggregateExprBuilder::new(
3939+
avg_udaf(),
3940+
vec![col("e", schema.as_ref())?],
3941+
)
3942+
.schema(Arc::clone(&schema))
3943+
.alias("AVG(e)")
3944+
.build()?,
3945+
),
3946+
],
3947+
vec![None, None, None, None, None],
3948+
Arc::new(scan) as Arc<dyn ExecutionPlan>,
3949+
Arc::clone(&schema),
3950+
)?);
3951+
3952+
// Pool must be large enough for accumulation to start but too small for
3953+
// sort_memory after clearing.
3954+
let task_ctx = new_spill_ctx(1, 500);
3955+
let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await;
3956+
3957+
match &result {
3958+
Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"),
3959+
Err(e) => {
3960+
let root = e.find_root();
3961+
assert!(
3962+
matches!(root, DataFusionError::ResourcesExhausted(_)),
3963+
"Expected ResourcesExhausted, got: {root}",
3964+
);
3965+
let msg = root.to_string();
3966+
assert!(
3967+
msg.contains("Failed to reserve memory for sort during spill"),
3968+
"Expected sort reservation error, got: {msg}",
3969+
);
3970+
}
3971+
}
3972+
3973+
Ok(())
3974+
}
3975+
38403976
/// Tests that PartialReduce mode:
38413977
/// 1. Accepts state as input (like Final)
38423978
/// 2. Produces state as output (like Partial)

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,16 @@ use crate::aggregates::{
3030
create_schema, evaluate_group_by, evaluate_many, evaluate_optional,
3131
};
3232
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
33-
use crate::sorts::sort::sort_batch;
3433
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
35-
use crate::spill::spill_manager::SpillManager;
34+
use crate::spill::spill_manager::{GetSlicedSize, SpillManager};
3635
use crate::{PhysicalExpr, aggregates, metrics};
3736
use crate::{RecordBatchStream, SendableRecordBatchStream};
3837

3938
use arrow::array::*;
4039
use arrow::datatypes::SchemaRef;
4140
use datafusion_common::{
4241
DataFusionError, Result, assert_eq_or_internal_err, assert_or_internal_err,
43-
internal_err,
42+
internal_err, resources_datafusion_err,
4443
};
4544
use datafusion_execution::TaskContext;
4645
use datafusion_execution::memory_pool::proxy::VecAllocExt;
@@ -51,7 +50,9 @@ use datafusion_physical_expr::expressions::Column;
5150
use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr};
5251
use datafusion_physical_expr_common::sort_expr::LexOrdering;
5352

53+
use crate::sorts::IncrementalSortIterator;
5454
use datafusion_common::instant::Instant;
55+
use datafusion_common::utils::memory::get_record_batch_memory_size;
5556
use futures::ready;
5657
use futures::stream::{Stream, StreamExt};
5758
use log::debug;
@@ -1060,10 +1061,27 @@ impl GroupedHashAggregateStream {
10601061

10611062
fn update_memory_reservation(&mut self) -> Result<()> {
10621063
let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
1063-
let new_size = acc
1064+
let groups_and_acc_size = acc
10641065
+ self.group_values.size()
10651066
+ self.group_ordering.size()
10661067
+ self.current_group_indices.allocated_size();
1068+
1069+
// Reserve extra headroom for sorting during potential spill.
1070+
// When OOM triggers, group_aggregate_batch has already processed the
1071+
// latest input batch, so the internal state may have grown well beyond
1072+
// the last successful reservation. The emit batch reflects this larger
1073+
// actual state, and the sort needs memory proportional to it.
1074+
// By reserving headroom equal to the data size, we trigger OOM earlier
1075+
// (before too much data accumulates), ensuring the freed reservation
1076+
// after clear_shrink is sufficient to cover the sort memory.
1077+
let sort_headroom =
1078+
if self.oom_mode == OutOfMemoryMode::Spill && !self.group_values.is_empty() {
1079+
acc + self.group_values.size()
1080+
} else {
1081+
0
1082+
};
1083+
1084+
let new_size = groups_and_acc_size + sort_headroom;
10671085
let reservation_result = self.reservation.try_resize(new_size);
10681086

10691087
if reservation_result.is_ok() {
@@ -1122,17 +1140,47 @@ impl GroupedHashAggregateStream {
11221140
let Some(emit) = self.emit(EmitTo::All, true)? else {
11231141
return Ok(());
11241142
};
1125-
let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
11261143

1127-
// Spill sorted state to disk
1144+
// Free accumulated state now that data has been emitted into `emit`.
1145+
// This must happen before reserving sort memory so the pool has room.
1146+
// Use 0 to minimize allocated capacity and maximize memory available for sorting.
1147+
self.clear_shrink(0);
1148+
self.update_memory_reservation()?;
1149+
1150+
let batch_size_ratio = self.batch_size as f32 / emit.num_rows() as f32;
1151+
let batch_memory = get_record_batch_memory_size(&emit);
1152+
// The maximum worst case for a sort is 2X the original underlying buffers(regardless of slicing)
1153+
// First we get the underlying buffers' size, then we get the sliced("actual") size of the batch,
1154+
// and multiply it by the ratio of batch_size to actual size to get the estimated memory needed for sorting the batch.
1155+
// If something goes wrong in get_sliced_size()(double counting or something),
1156+
// we fall back to the worst case.
1157+
let sort_memory = (batch_memory
1158+
+ (emit.get_sliced_size()? as f32 * batch_size_ratio) as usize)
1159+
.min(batch_memory * 2);
1160+
1161+
// If we can't grow even that, we have no choice but to return an error since we can't spill to disk without sorting the data first.
1162+
self.reservation.try_grow(sort_memory).map_err(|err| {
1163+
resources_datafusion_err!(
1164+
"Failed to reserve memory for sort during spill: {err}"
1165+
)
1166+
})?;
1167+
1168+
let sorted_iter = IncrementalSortIterator::new(
1169+
emit,
1170+
self.spill_state.spill_expr.clone(),
1171+
self.batch_size,
1172+
);
11281173
let spillfile = self
11291174
.spill_state
11301175
.spill_manager
1131-
.spill_record_batch_by_size_and_return_max_batch_memory(
1132-
&sorted,
1176+
.spill_record_batch_iter_and_return_max_batch_memory(
1177+
sorted_iter,
11331178
"HashAggSpill",
1134-
self.batch_size,
11351179
)?;
1180+
1181+
// Shrink the memory we allocated for sorting as the sorting is fully done at this point.
1182+
self.reservation.shrink(sort_memory);
1183+
11361184
match spillfile {
11371185
Some((spillfile, max_record_batch_memory)) => {
11381186
self.spill_state.spills.push(SortedSpillFile {
@@ -1150,14 +1198,14 @@ impl GroupedHashAggregateStream {
11501198
Ok(())
11511199
}
11521200

1153-
/// Clear memory and shirk capacities to the size of the batch.
1201+
/// Clear memory and shrink capacities to the given number of rows.
11541202
fn clear_shrink(&mut self, num_rows: usize) {
11551203
self.group_values.clear_shrink(num_rows);
11561204
self.current_group_indices.clear();
11571205
self.current_group_indices.shrink_to(num_rows);
11581206
}
11591207

1160-
/// Clear memory and shirk capacities to zero.
1208+
/// Clear memory and shrink capacities to zero.
11611209
fn clear_all(&mut self) {
11621210
self.clear_shrink(0);
11631211
}
@@ -1196,7 +1244,7 @@ impl GroupedHashAggregateStream {
11961244
// instead.
11971245
// Spilling to disk and reading back also ensures batch size is consistent
11981246
// rather than potentially having one significantly larger last batch.
1199-
self.spill()?; // TODO: use sort_batch_chunked instead?
1247+
self.spill()?;
12001248

12011249
// Mark that we're switching to stream merging mode.
12021250
self.spill_state.is_stream_merging = true;

datafusion/physical-plan/src/sorts/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,5 @@ pub mod sort;
2626
pub mod sort_preserving_merge;
2727
mod stream;
2828
pub mod streaming_merge;
29+
30+
pub(crate) use stream::IncrementalSortIterator;

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 9 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use crate::metrics::{
3939
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics,
4040
};
4141
use crate::projection::{ProjectionExec, make_with_child, update_ordering};
42+
use crate::sorts::IncrementalSortIterator;
4243
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
4344
use crate::spill::get_record_batch_memory_size;
4445
use crate::spill::in_progress_spill_file::InProgressSpillFile;
@@ -728,7 +729,6 @@ impl ExternalSorter {
728729

729730
// Sort the batch immediately and get all output batches
730731
let sorted_batches = sort_batch_chunked(&batch, &expressions, batch_size)?;
731-
drop(batch);
732732

733733
// Free the old reservation and grow it to match the actual sorted output size
734734
reservation.free();
@@ -853,11 +853,13 @@ pub(crate) fn get_reserved_bytes_for_record_batch_size(
853853
/// Estimate how much memory is needed to sort a `RecordBatch`.
854854
/// This will just call `get_reserved_bytes_for_record_batch_size` with the
855855
/// memory size of the record batch and its sliced size.
856-
pub(super) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result<usize> {
857-
Ok(get_reserved_bytes_for_record_batch_size(
858-
get_record_batch_memory_size(batch),
859-
batch.get_sliced_size()?,
860-
))
856+
pub(crate) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result<usize> {
857+
batch.get_sliced_size().map(|sliced_size| {
858+
get_reserved_bytes_for_record_batch_size(
859+
get_record_batch_memory_size(batch),
860+
sliced_size,
861+
)
862+
})
861863
}
862864

863865
impl Debug for ExternalSorter {
@@ -900,38 +902,7 @@ pub fn sort_batch_chunked(
900902
expressions: &LexOrdering,
901903
batch_size: usize,
902904
) -> Result<Vec<RecordBatch>> {
903-
let sort_columns = expressions
904-
.iter()
905-
.map(|expr| expr.evaluate_to_sort_column(batch))
906-
.collect::<Result<Vec<_>>>()?;
907-
908-
let indices = lexsort_to_indices(&sort_columns, None)?;
909-
910-
// Split indices into chunks of batch_size
911-
let num_rows = indices.len();
912-
let num_chunks = num_rows.div_ceil(batch_size);
913-
914-
let result_batches = (0..num_chunks)
915-
.map(|chunk_idx| {
916-
let start = chunk_idx * batch_size;
917-
let end = (start + batch_size).min(num_rows);
918-
let chunk_len = end - start;
919-
920-
// Create a slice of indices for this chunk
921-
let chunk_indices = indices.slice(start, chunk_len);
922-
923-
// Take the columns using this chunk of indices
924-
let columns = take_arrays(batch.columns(), &chunk_indices, None)?;
925-
926-
let options = RecordBatchOptions::new().with_row_count(Some(chunk_len));
927-
let chunk_batch =
928-
RecordBatch::try_new_with_options(batch.schema(), columns, &options)?;
929-
930-
Ok(chunk_batch)
931-
})
932-
.collect::<Result<Vec<RecordBatch>>>()?;
933-
934-
Ok(result_batches)
905+
IncrementalSortIterator::new(batch.clone(), expressions.clone(), batch_size).collect()
935906
}
936907

937908
/// Sort execution plan.

0 commit comments

Comments
 (0)