From 9173fa2d4849f3a64689a91c48fe2e473eb4a0dd Mon Sep 17 00:00:00 2001
From: Nathan Bezualem
Date: Mon, 6 Apr 2026 21:21:59 -0400
Subject: [PATCH 1/2] feat: enable external reclaim for mem spillable df operators

---
 datafusion/execution/src/memory_pool/mod.rs  |  48 +++++-
 datafusion/execution/src/memory_pool/pool.rs | 157 ++++++++++++++++++-
 2 files changed, 202 insertions(+), 3 deletions(-)

diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs
index 0b4eb3786f555..41fdd109de782 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -205,6 +205,19 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug {
     /// On error the `allocation` will not be increased in size
     fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()>;
 
+    /// Attempt to reclaim `target_bytes` from existing spillable consumers already registered
+    /// with this pool.
+    ///
+    /// `exclude_consumer_id`, when provided, identifies the current requester and should not be
+    /// reclaimed from to avoid re-entering the same operator while it is mid-allocation.
+    fn reclaim(
+        &self,
+        _target_bytes: usize,
+        _exclude_consumer_id: Option<usize>,
+    ) -> Result<usize> {
+        Ok(0)
+    }
+
     /// Return the total amount of memory reserved
     fn reserved(&self) -> usize;
 
@@ -240,11 +253,22 @@ pub enum MemoryLimit {
 /// For help with allocation accounting, see the [`proxy`] module.
 ///
 /// [proxy]: datafusion_common::utils::proxy
-#[derive(Debug)]
 pub struct MemoryConsumer {
     name: String,
     can_spill: bool,
     id: usize,
+    reclaimer: Option<Arc<dyn MemoryReclaimer>>,
+}
+
+impl std::fmt::Debug for MemoryConsumer {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("MemoryConsumer")
+            .field("name", &self.name)
+            .field("can_spill", &self.can_spill)
+            .field("id", &self.id)
+            .field("has_reclaimer", &self.reclaimer.is_some())
+            .finish()
+    }
 }
 
 impl PartialEq for MemoryConsumer {
@@ -283,6 +307,7 @@ impl MemoryConsumer {
             name: name.into(),
             can_spill: false,
             id: Self::new_unique_id(),
+            reclaimer: None,
         }
     }
 
@@ -294,6 +319,7 @@ impl MemoryConsumer {
             name: self.name.clone(),
             can_spill: self.can_spill,
             id: Self::new_unique_id(),
+            reclaimer: self.reclaimer.clone(),
         }
     }
 
@@ -307,6 +333,15 @@ impl MemoryConsumer {
         Self { can_spill, ..self }
     }
 
+    /// Configure a callback that can reclaim memory from this consumer when another consumer in
+    /// the same pool is under pressure.
+    pub fn with_reclaimer(self, reclaimer: Arc<dyn MemoryReclaimer>) -> Self {
+        Self {
+            reclaimer: Some(reclaimer),
+            ..self
+        }
+    }
+
     /// Returns true if this allocation can spill to disk
     pub fn can_spill(&self) -> bool {
         self.can_spill
@@ -317,6 +352,11 @@ impl MemoryConsumer {
         &self.name
     }
 
+    /// Returns the reclaim callback registered for this consumer, if any.
+    pub fn reclaimer(&self) -> Option<Arc<dyn MemoryReclaimer>> {
+        self.reclaimer.clone()
+    }
+
     /// Registers this [`MemoryConsumer`] with the provided [`MemoryPool`] returning
     /// a [`MemoryReservation`] that can be used to grow or shrink the memory reservation
     pub fn register(self, pool: &Arc<dyn MemoryPool>) -> MemoryReservation {
@@ -331,6 +371,12 @@ impl MemoryConsumer {
     }
 }
 
+/// Callback implemented by spillable operators that can synchronously reclaim existing
+/// reservations when another consumer in the same pool is under pressure.
+pub trait MemoryReclaimer: Send + Sync {
+    fn reclaim(&self, target_bytes: usize) -> Result<usize>;
+}
+
 /// A registration of a [`MemoryConsumer`] with a [`MemoryPool`].
 ///
 /// Calls [`MemoryPool::unregister`] on drop to return any memory to
diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs
index b10270851cc06..10299a4d52b7f 100644
--- a/datafusion/execution/src/memory_pool/pool.rs
+++ b/datafusion/execution/src/memory_pool/pool.rs
@@ -16,7 +16,8 @@
 // under the License.
 
 use crate::memory_pool::{
-    MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation, human_readable_size,
+    MemoryConsumer, MemoryLimit, MemoryPool, MemoryReclaimer, MemoryReservation,
+    human_readable_size,
 };
 use datafusion_common::HashMap;
 use datafusion_common::{DataFusionError, Result, resources_datafusion_err};
@@ -24,6 +25,7 @@ use log::debug;
 use parking_lot::Mutex;
 use std::{
     num::NonZeroUsize,
+    sync::Arc,
     sync::atomic::{AtomicUsize, Ordering},
 };
 
@@ -269,12 +271,24 @@ fn insufficient_capacity_err(
     )
 }
 
-#[derive(Debug)]
 struct TrackedConsumer {
     name: String,
     can_spill: bool,
     reserved: AtomicUsize,
     peak: AtomicUsize,
+    reclaimer: Option<Arc<dyn MemoryReclaimer>>,
+}
+
+impl std::fmt::Debug for TrackedConsumer {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("TrackedConsumer")
+            .field("name", &self.name)
+            .field("can_spill", &self.can_spill)
+            .field("reserved", &self.reserved())
+            .field("peak", &self.peak())
+            .field("has_reclaimer", &self.reclaimer.is_some())
+            .finish()
+    }
 }
 
 impl TrackedConsumer {
@@ -428,6 +442,7 @@ impl<I: MemoryPool> MemoryPool for TrackConsumersPool<I> {
                 can_spill: consumer.can_spill(),
                 reserved: Default::default(),
                 peak: Default::default(),
+                reclaimer: consumer.reclaimer(),
             },
         );
 
@@ -488,6 +503,50 @@ impl<I: MemoryPool> MemoryPool for TrackConsumersPool<I> {
         Ok(())
     }
 
+    fn reclaim(
+        &self,
+        target_bytes: usize,
+        exclude_consumer_id: Option<usize>,
+    ) -> Result<usize> {
+        if target_bytes == 0 {
+            return Ok(0);
+        }
+
+        let mut candidates = self
+            .tracked_consumers
+            .lock()
+            .iter()
+            .filter_map(|(consumer_id, tracked_consumer)| {
+                let reserved = tracked_consumer.reserved();
+                let reclaimer = tracked_consumer.reclaimer.as_ref()?;
+                if exclude_consumer_id == Some(*consumer_id)
+                    || !tracked_consumer.can_spill
+                    || reserved == 0
+                {
+                    return None;
+                }
+
+                Some((*consumer_id, reserved, Arc::clone(reclaimer)))
+            })
+            .collect::<Vec<_>>();
+        candidates.sort_by(
+            |(left_id, left_reserved, _), (right_id, right_reserved, _)| {
+                right_reserved
+                    .cmp(left_reserved)
+                    .then_with(|| left_id.cmp(right_id))
+            },
+        );
+
+        let mut reclaimed = 0;
+        for (_, _, reclaimer) in candidates {
+            if reclaimed >= target_bytes {
+                break;
+            }
+            reclaimed += reclaimer.reclaim(target_bytes - reclaimed)?;
+        }
+        Ok(reclaimed)
+    }
+
     fn reserved(&self) -> usize {
         self.inner.reserved()
     }
@@ -513,6 +572,24 @@ mod tests {
     use insta::{Settings, allow_duplicates, assert_snapshot};
     use std::sync::Arc;
 
+    #[derive(Debug)]
+    struct TestReclaimer {
+        reservation: Arc<Mutex<Option<Arc<MemoryReservation>>>>,
+    }
+
+    impl MemoryReclaimer for TestReclaimer {
+        fn reclaim(&self, target_bytes: usize) -> Result<usize> {
+            let Some(reservation) = self.reservation.lock().clone() else {
+                return Ok(0);
+            };
+            let reclaimed = reservation.size().min(target_bytes);
+            if reclaimed > 0 {
+                reservation.shrink(reclaimed);
+            }
+            Ok(reclaimed)
+        }
+    }
+
     fn make_settings() -> Settings {
         let mut settings = Settings::clone_current();
         settings.add_filter(
@@ -811,4 +888,80 @@ mod tests {
             r1#[ID](can spill: false) consumed 20.0 B, peak 20.0 B.
         ");
     }
+
+    #[test]
+    fn test_tracked_consumers_pool_reclaim_prefers_largest_consumer() {
+        let pool = Arc::new(TrackConsumersPool::new(
+            GreedyMemoryPool::new(200),
+            NonZeroUsize::new(3).unwrap(),
+        )) as Arc<dyn MemoryPool>;
+
+        let first_reservation_handle = Arc::new(Mutex::new(None));
+        let first = Arc::new(
+            MemoryConsumer::new("spillable-1")
+                .with_can_spill(true)
+                .with_reclaimer(Arc::new(TestReclaimer {
+                    reservation: Arc::clone(&first_reservation_handle),
+                }))
+                .register(&pool),
+        );
+        *first_reservation_handle.lock() = Some(Arc::clone(&first));
+        first.grow(100);
+
+        let second_reservation_handle = Arc::new(Mutex::new(None));
+        let second = Arc::new(
+            MemoryConsumer::new("spillable-2")
+                .with_can_spill(true)
+                .with_reclaimer(Arc::new(TestReclaimer {
+                    reservation: Arc::clone(&second_reservation_handle),
+                }))
+                .register(&pool),
+        );
+        *second_reservation_handle.lock() = Some(Arc::clone(&second));
+        second.grow(60);
+
+        let reclaimed = pool.reclaim(80, None).unwrap();
+
+        assert_eq!(reclaimed, 80);
+        assert_eq!(first.size(), 20);
+        assert_eq!(second.size(), 60);
+    }
+
+    #[test]
+    fn test_tracked_consumers_pool_reclaim_excludes_requester() {
+        let pool = Arc::new(TrackConsumersPool::new(
+            GreedyMemoryPool::new(200),
+            NonZeroUsize::new(3).unwrap(),
+        )) as Arc<dyn MemoryPool>;
+
+        let first_reservation_handle = Arc::new(Mutex::new(None));
+        let first = Arc::new(
+            MemoryConsumer::new("spillable-1")
+                .with_can_spill(true)
+                .with_reclaimer(Arc::new(TestReclaimer {
+                    reservation: Arc::clone(&first_reservation_handle),
+                }))
+                .register(&pool),
+        );
+        *first_reservation_handle.lock() = Some(Arc::clone(&first));
+        first.grow(100);
+
+        let second_reservation_handle = Arc::new(Mutex::new(None));
+        let second = Arc::new(
+            MemoryConsumer::new("spillable-2")
+                .with_can_spill(true)
+                .with_reclaimer(Arc::new(TestReclaimer {
+                    reservation: Arc::clone(&second_reservation_handle),
+                }))
+                .register(&pool),
+        );
+        *second_reservation_handle.lock() = Some(Arc::clone(&second));
+        second.grow(60);
+
+        let reclaimed = pool.reclaim(80, Some(first.consumer().id())).unwrap();
+
+        assert_eq!(reclaimed, 60);
+        assert_eq!(first.size(), 100);
+        assert_eq!(second.size(), 0);
+    }
 }

From 4bf1cafba154820a60a0aa561a9aecbc78c9f132 Mon Sep 17 00:00:00 2001
From: Nathan Bezualem
Date: Fri, 17 Apr 2026 11:23:10 -0400
Subject: [PATCH 2/2] fix: tighten memory reclaimer ownership and fallback behavior

---
 datafusion/execution/src/memory_pool/mod.rs  | 50 ++++++++++++-
 datafusion/execution/src/memory_pool/pool.rs | 76 ++++++++++++++++----
 2 files changed, 113 insertions(+), 13 deletions(-)

diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs
index 41fdd109de782..24b6581f7c818 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -314,12 +314,17 @@ impl MemoryConsumer {
     /// Returns a clone of this [`MemoryConsumer`] with a new unique id,
     /// which can be registered with a [`MemoryPool`],
     /// This new consumer is separate from the original.
+    ///
+    /// The cloned consumer intentionally does not inherit any registered
+    /// [`MemoryReclaimer`]. Reclaimers are expected to be tied to the original
+    /// spillable operator state, and carrying them across a new consumer id can
+    /// cause externally-triggered reclaim to target the wrong owner.
     pub fn clone_with_new_id(&self) -> Self {
         Self {
             name: self.name.clone(),
             can_spill: self.can_spill,
             id: Self::new_unique_id(),
-            reclaimer: self.reclaimer.clone(),
+            reclaimer: None,
         }
     }
 
@@ -335,8 +340,11 @@ impl MemoryConsumer {
     /// Configure a callback that can reclaim memory from this consumer when another consumer in
     /// the same pool is under pressure.
+    ///
+    /// A consumer with a reclaimer is considered spill-capable by default.
     pub fn with_reclaimer(self, reclaimer: Arc<dyn MemoryReclaimer>) -> Self {
         Self {
+            can_spill: true,
             reclaimer: Some(reclaimer),
             ..self
         }
     }
@@ -373,6 +381,13 @@ impl MemoryConsumer {
 
 /// Callback implemented by spillable operators that can synchronously reclaim existing
 /// reservations when another consumer in the same pool is under pressure.
+///
+/// Implementations should:
+///
+/// - only report bytes actually released from pool-tracked reservations
+/// - return at most `target_bytes`
+/// - avoid holding strong reference cycles back to [`MemoryReservation`]s,
+///   [`MemoryConsumer`]s, or other state that keeps those objects alive
 pub trait MemoryReclaimer: Send + Sync {
     fn reclaim(&self, target_bytes: usize) -> Result<usize>;
 }
@@ -549,6 +564,16 @@ impl Drop for MemoryReservation {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use std::sync::Arc;
+
+    #[derive(Debug)]
+    struct NoopReclaimer;
+
+    impl MemoryReclaimer for NoopReclaimer {
+        fn reclaim(&self, _target_bytes: usize) -> Result<usize> {
+            Ok(0)
+        }
+    }
 
     #[test]
     fn test_id_uniqueness() {
@@ -675,4 +700,27 @@ mod tests {
         assert_eq!(r1.size(), 0);
         assert_eq!(pool.reserved(), 80);
     }
+
+    #[test]
+    fn test_clone_with_new_id_drops_reclaimer() {
+        let consumer =
+            MemoryConsumer::new("spillable").with_reclaimer(Arc::new(NoopReclaimer));
+
+        let clone = consumer.clone_with_new_id();
+
+        assert!(consumer.reclaimer().is_some());
+        assert!(consumer.can_spill());
+        assert!(clone.reclaimer().is_none());
+        assert!(clone.can_spill());
+        assert_ne!(consumer.id(), clone.id());
+    }
+
+    #[test]
+    fn test_with_reclaimer_marks_consumer_spillable() {
+        let consumer =
+            MemoryConsumer::new("spillable").with_reclaimer(Arc::new(NoopReclaimer));
+
+        assert!(consumer.can_spill());
+        assert!(consumer.reclaimer().is_some());
+    }
 }
diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs
index 10299a4d52b7f..fb00f9cddd0ac 100644
--- a/datafusion/execution/src/memory_pool/pool.rs
+++ b/datafusion/execution/src/memory_pool/pool.rs
@@ -544,7 +544,14 @@ impl<I: MemoryPool> MemoryPool for TrackConsumersPool<I> {
             }
             reclaimed += reclaimer.reclaim(target_bytes - reclaimed)?;
         }
-        Ok(reclaimed)
+        if reclaimed >= target_bytes {
+            return Ok(reclaimed);
+        }
+
+        Ok(reclaimed
+            + self
+                .inner
+                .reclaim(target_bytes - reclaimed, exclude_consumer_id)?)
     }
 
     fn reserved(&self) -> usize {
@@ -570,16 +577,16 @@ fn provide_top_memory_consumers_to_error_msg(
 mod tests {
     use super::*;
     use insta::{Settings, allow_duplicates, assert_snapshot};
-    use std::sync::Arc;
+    use std::sync::{Arc, Weak};
 
     #[derive(Debug)]
     struct TestReclaimer {
-        reservation: Arc<Mutex<Option<Arc<MemoryReservation>>>>,
+        reservation: Arc<Mutex<Weak<MemoryReservation>>>,
     }
 
     impl MemoryReclaimer for TestReclaimer {
         fn reclaim(&self, target_bytes: usize) -> Result<usize> {
-            let Some(reservation) = self.reservation.lock().clone() else {
+            let Some(reservation) = self.reservation.lock().upgrade() else {
                 return Ok(0);
             };
             let reclaimed = reservation.size().min(target_bytes);
@@ -896,7 +903,7 @@ mod tests {
             NonZeroUsize::new(3).unwrap(),
         )) as Arc<dyn MemoryPool>;
 
-        let first_reservation_handle = Arc::new(Mutex::new(None));
+        let first_reservation_handle = Arc::new(Mutex::new(Weak::new()));
         let first = Arc::new(
             MemoryConsumer::new("spillable-1")
                 .with_can_spill(true)
@@ -905,10 +912,10 @@ mod tests {
                 }))
                 .register(&pool),
         );
-        *first_reservation_handle.lock() = Some(Arc::clone(&first));
+        *first_reservation_handle.lock() = Arc::downgrade(&first);
         first.grow(100);
 
-        let second_reservation_handle = Arc::new(Mutex::new(None));
+        let second_reservation_handle = Arc::new(Mutex::new(Weak::new()));
         let second = Arc::new(
             MemoryConsumer::new("spillable-2")
                 .with_can_spill(true)
@@ -917,7 +924,7 @@ mod tests {
                 }))
                 .register(&pool),
         );
-        *second_reservation_handle.lock() = Some(Arc::clone(&second));
+        *second_reservation_handle.lock() = Arc::downgrade(&second);
         second.grow(60);
 
         let reclaimed = pool.reclaim(80, None).unwrap();
@@ -934,7 +941,7 @@ mod tests {
             NonZeroUsize::new(3).unwrap(),
         )) as Arc<dyn MemoryPool>;
 
-        let first_reservation_handle = Arc::new(Mutex::new(None));
+        let first_reservation_handle = Arc::new(Mutex::new(Weak::new()));
         let first = Arc::new(
             MemoryConsumer::new("spillable-1")
                 .with_can_spill(true)
@@ -943,10 +950,10 @@ mod tests {
                 }))
                 .register(&pool),
        );
-        *first_reservation_handle.lock() = Some(Arc::clone(&first));
+        *first_reservation_handle.lock() = Arc::downgrade(&first);
         first.grow(100);
 
-        let second_reservation_handle = Arc::new(Mutex::new(None));
+        let second_reservation_handle = Arc::new(Mutex::new(Weak::new()));
         let second = Arc::new(
             MemoryConsumer::new("spillable-2")
                 .with_can_spill(true)
@@ -955,7 +962,7 @@ mod tests {
                 }))
                 .register(&pool),
         );
-        *second_reservation_handle.lock() = Some(Arc::clone(&second));
+        *second_reservation_handle.lock() = Arc::downgrade(&second);
         second.grow(60);
 
         let reclaimed = pool.reclaim(80, Some(first.consumer().id())).unwrap();
@@ -964,4 +971,49 @@ mod tests {
         assert_eq!(first.size(), 100);
         assert_eq!(second.size(), 0);
     }
+
+    #[derive(Debug, Default)]
+    struct InnerReclaimOnlyPool {
+        reclaimed: AtomicUsize,
+    }
+
+    impl MemoryPool for InnerReclaimOnlyPool {
+        fn grow(&self, _reservation: &MemoryReservation, _additional: usize) {}
+
+        fn shrink(&self, _reservation: &MemoryReservation, _shrink: usize) {}
+
+        fn try_grow(
+            &self,
+            _reservation: &MemoryReservation,
+            _additional: usize,
+        ) -> Result<()> {
+            Ok(())
+        }
+
+        fn reclaim(
+            &self,
+            target_bytes: usize,
+            _exclude_consumer_id: Option<usize>,
+        ) -> Result<usize> {
+            self.reclaimed.fetch_add(target_bytes, Ordering::Relaxed);
+            Ok(target_bytes)
+        }
+
+        fn reserved(&self) -> usize {
+            0
+        }
+    }
+
+    #[test]
+    fn test_tracked_consumers_pool_reclaim_falls_back_to_inner_pool() {
+        let pool = TrackConsumersPool::new(
+            InnerReclaimOnlyPool::default(),
+            NonZeroUsize::new(3).unwrap(),
+        );
+
+        let reclaimed = pool.reclaim(64, None).unwrap();
+
+        assert_eq!(reclaimed, 64);
+        assert_eq!(pool.inner.reclaimed.load(Ordering::Relaxed), 64);
+    }
 }