diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 0b4eb3786f555..24b6581f7c818 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -205,6 +205,19 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// On error the `allocation` will not be increased in size fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()>; + /// Attempt to reclaim `target_bytes` from existing spillable consumers already registered + /// with this pool. + /// + /// `exclude_consumer_id`, when provided, identifies the current requester and should not be + /// reclaimed from to avoid re-entering the same operator while it is mid-allocation. + fn reclaim( + &self, + _target_bytes: usize, + _exclude_consumer_id: Option<usize>, + ) -> Result<usize> { + Ok(0) + } + /// Return the total amount of memory reserved fn reserved(&self) -> usize; @@ -240,11 +253,22 @@ pub enum MemoryLimit { /// For help with allocation accounting, see the [`proxy`] module. /// /// [proxy]: datafusion_common::utils::proxy -#[derive(Debug)] pub struct MemoryConsumer { name: String, can_spill: bool, id: usize, + reclaimer: Option<Arc<dyn MemoryReclaimer>>, +} + +impl std::fmt::Debug for MemoryConsumer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MemoryConsumer") + .field("name", &self.name) + .field("can_spill", &self.can_spill) + .field("id", &self.id) + .field("has_reclaimer", &self.reclaimer.is_some()) + .finish() + } } impl PartialEq for MemoryConsumer { @@ -283,17 +307,24 @@ impl MemoryConsumer { name: name.into(), can_spill: false, id: Self::new_unique_id(), + reclaimer: None, } } /// Returns a clone of this [`MemoryConsumer`] with a new unique id, /// which can be registered with a [`MemoryPool`], /// This new consumer is separate from the original. + /// + /// The cloned consumer intentionally does not inherit any registered + /// [`MemoryReclaimer`]. 
Reclaimers are expected to be tied to the original + /// spillable operator state, and carrying them across a new consumer id can + /// cause externally-triggered reclaim to target the wrong owner. pub fn clone_with_new_id(&self) -> Self { Self { name: self.name.clone(), can_spill: self.can_spill, id: Self::new_unique_id(), + reclaimer: None, } } @@ -307,6 +338,18 @@ impl MemoryConsumer { Self { can_spill, ..self } } + /// Configure a callback that can reclaim memory from this consumer when another consumer in + /// the same pool is under pressure. + /// + /// A consumer with a reclaimer is considered spill-capable by default. + pub fn with_reclaimer(self, reclaimer: Arc<dyn MemoryReclaimer>) -> Self { + Self { + can_spill: true, + reclaimer: Some(reclaimer), + ..self + } + } + /// Returns true if this allocation can spill to disk pub fn can_spill(&self) -> bool { self.can_spill @@ -317,6 +360,11 @@ impl MemoryConsumer { &self.name } + /// Returns the reclaim callback registered for this consumer, if any. + pub fn reclaimer(&self) -> Option<Arc<dyn MemoryReclaimer>> { + self.reclaimer.clone() + } + /// Registers this [`MemoryConsumer`] with the provided [`MemoryPool`] returning /// a [`MemoryReservation`] that can be used to grow or shrink the memory reservation pub fn register(self, pool: &Arc<dyn MemoryPool>) -> MemoryReservation { @@ -331,6 +379,19 @@ impl MemoryConsumer { } } +/// Callback implemented by spillable operators that can synchronously reclaim existing +/// reservations when another consumer in the same pool is under pressure. +/// +/// Implementations should: +/// +/// - only report bytes actually released from pool-tracked reservations +/// - return at most `target_bytes` +/// - avoid holding strong reference cycles back to [`MemoryReservation`]s, +/// [`MemoryConsumer`]s, or other state that keeps those objects alive +pub trait MemoryReclaimer: Send + Sync { + fn reclaim(&self, target_bytes: usize) -> Result<usize>; +} + /// A registration of a [`MemoryConsumer`] with a [`MemoryPool`]. 
/// /// Calls [`MemoryPool::unregister`] on drop to return any memory to @@ -503,6 +564,16 @@ impl Drop for MemoryReservation { #[cfg(test)] mod tests { use super::*; + use std::sync::Arc; + + #[derive(Debug)] + struct NoopReclaimer; + + impl MemoryReclaimer for NoopReclaimer { + fn reclaim(&self, _target_bytes: usize) -> Result<usize> { + Ok(0) + } + } #[test] fn test_id_uniqueness() { @@ -629,4 +700,27 @@ mod tests { assert_eq!(r1.size(), 0); assert_eq!(pool.reserved(), 80); } + + #[test] + fn test_clone_with_new_id_drops_reclaimer() { + let consumer = + MemoryConsumer::new("spillable").with_reclaimer(Arc::new(NoopReclaimer)); + + let clone = consumer.clone_with_new_id(); + + assert!(consumer.reclaimer().is_some()); + assert!(consumer.can_spill()); + assert!(clone.reclaimer().is_none()); + assert!(clone.can_spill()); + assert_ne!(consumer.id(), clone.id()); + } + + #[test] + fn test_with_reclaimer_marks_consumer_spillable() { + let consumer = + MemoryConsumer::new("spillable").with_reclaimer(Arc::new(NoopReclaimer)); + + assert!(consumer.can_spill()); + assert!(consumer.reclaimer().is_some()); + } } diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index b10270851cc06..fb00f9cddd0ac 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -16,7 +16,8 @@ // under the License. 
use crate::memory_pool::{ - MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation, human_readable_size, + MemoryConsumer, MemoryLimit, MemoryPool, MemoryReclaimer, MemoryReservation, + human_readable_size, }; use datafusion_common::HashMap; use datafusion_common::{DataFusionError, Result, resources_datafusion_err}; @@ -24,6 +25,7 @@ use log::debug; use parking_lot::Mutex; use std::{ num::NonZeroUsize, + sync::Arc, sync::atomic::{AtomicUsize, Ordering}, }; @@ -269,12 +271,24 @@ fn insufficient_capacity_err( ) } -#[derive(Debug)] struct TrackedConsumer { name: String, can_spill: bool, reserved: AtomicUsize, peak: AtomicUsize, + reclaimer: Option<Arc<dyn MemoryReclaimer>>, +} + +impl std::fmt::Debug for TrackedConsumer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TrackedConsumer") + .field("name", &self.name) + .field("can_spill", &self.can_spill) + .field("reserved", &self.reserved()) + .field("peak", &self.peak()) + .field("has_reclaimer", &self.reclaimer.is_some()) + .finish() + } } impl TrackedConsumer { @@ -428,6 +442,7 @@ impl<I: MemoryPool> MemoryPool for TrackConsumersPool<I> { can_spill: consumer.can_spill(), reserved: Default::default(), peak: Default::default(), + reclaimer: consumer.reclaimer(), }, ); @@ -488,6 +503,57 @@ impl<I: MemoryPool> MemoryPool for TrackConsumersPool<I> { Ok(()) } + fn reclaim( + &self, + target_bytes: usize, + exclude_consumer_id: Option<usize>, + ) -> Result<usize> { + if target_bytes == 0 { + return Ok(0); + } + + let mut candidates = self + .tracked_consumers + .lock() + .iter() + .filter_map(|(consumer_id, tracked_consumer)| { + let reserved = tracked_consumer.reserved(); + let reclaimer = tracked_consumer.reclaimer.as_ref()?; + if exclude_consumer_id == Some(*consumer_id) + || !tracked_consumer.can_spill + || reserved == 0 + { + return None; + } + + Some((*consumer_id, reserved, Arc::clone(reclaimer))) + }) + .collect::<Vec<_>>(); + candidates.sort_by( + |(left_id, left_reserved, _), (right_id, right_reserved, _)| { + right_reserved + .cmp(left_reserved) + 
.then_with(|| left_id.cmp(right_id)) + }, + ); + + let mut reclaimed = 0; + for (_, _, reclaimer) in candidates { + if reclaimed >= target_bytes { + break; + } + reclaimed += reclaimer.reclaim(target_bytes - reclaimed)?; + } + if reclaimed >= target_bytes { + return Ok(reclaimed); + } + + Ok(reclaimed + + self + .inner + .reclaim(target_bytes - reclaimed, exclude_consumer_id)?) + } + fn reserved(&self) -> usize { self.inner.reserved() } @@ -511,7 +577,25 @@ fn provide_top_memory_consumers_to_error_msg( mod tests { use super::*; use insta::{Settings, allow_duplicates, assert_snapshot}; - use std::sync::Arc; + use std::sync::{Arc, Weak}; + + #[derive(Debug)] + struct TestReclaimer { + reservation: Arc<Mutex<Weak<MemoryReservation>>>, + } + + impl MemoryReclaimer for TestReclaimer { + fn reclaim(&self, target_bytes: usize) -> Result<usize> { + let Some(reservation) = self.reservation.lock().upgrade() else { + return Ok(0); + }; + let reclaimed = reservation.size().min(target_bytes); + if reclaimed > 0 { + reservation.shrink(reclaimed); + } + Ok(reclaimed) + } + } fn make_settings() -> Settings { let mut settings = Settings::clone_current(); @@ -811,4 +895,125 @@ mod tests { r1#[ID](can spill: false) consumed 20.0 B, peak 20.0 B. 
"); } + + #[test] + fn test_tracked_consumers_pool_reclaim_prefers_largest_consumer() { + let pool = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(200), + NonZeroUsize::new(3).unwrap(), + )) as Arc<dyn MemoryPool>; + + let first_reservation_handle = Arc::new(Mutex::new(Weak::new())); + let first = Arc::new( + MemoryConsumer::new("spillable-1") + .with_can_spill(true) + .with_reclaimer(Arc::new(TestReclaimer { + reservation: Arc::clone(&first_reservation_handle), + })) + .register(&pool), + ); + *first_reservation_handle.lock() = Arc::downgrade(&first); + first.grow(100); + + let second_reservation_handle = Arc::new(Mutex::new(Weak::new())); + let second = Arc::new( + MemoryConsumer::new("spillable-2") + .with_can_spill(true) + .with_reclaimer(Arc::new(TestReclaimer { + reservation: Arc::clone(&second_reservation_handle), + })) + .register(&pool), + ); + *second_reservation_handle.lock() = Arc::downgrade(&second); + second.grow(60); + + let reclaimed = pool.reclaim(80, None).unwrap(); + + assert_eq!(reclaimed, 80); + assert_eq!(first.size(), 20); + assert_eq!(second.size(), 60); + } + + #[test] + fn test_tracked_consumers_pool_reclaim_excludes_requester() { + let pool = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(200), + NonZeroUsize::new(3).unwrap(), + )) as Arc<dyn MemoryPool>; + + let first_reservation_handle = Arc::new(Mutex::new(Weak::new())); + let first = Arc::new( + MemoryConsumer::new("spillable-1") + .with_can_spill(true) + .with_reclaimer(Arc::new(TestReclaimer { + reservation: Arc::clone(&first_reservation_handle), + })) + .register(&pool), + ); + *first_reservation_handle.lock() = Arc::downgrade(&first); + first.grow(100); + + let second_reservation_handle = Arc::new(Mutex::new(Weak::new())); + let second = Arc::new( + MemoryConsumer::new("spillable-2") + .with_can_spill(true) + .with_reclaimer(Arc::new(TestReclaimer { + reservation: Arc::clone(&second_reservation_handle), + })) + .register(&pool), + ); + *second_reservation_handle.lock() = 
Arc::downgrade(&second); + second.grow(60); + + let reclaimed = pool.reclaim(80, Some(first.consumer().id())).unwrap(); + + assert_eq!(reclaimed, 60); + assert_eq!(first.size(), 100); + assert_eq!(second.size(), 0); + } + + #[derive(Debug, Default)] + struct InnerReclaimOnlyPool { + reclaimed: AtomicUsize, + } + + impl MemoryPool for InnerReclaimOnlyPool { + fn grow(&self, _reservation: &MemoryReservation, _additional: usize) {} + + fn shrink(&self, _reservation: &MemoryReservation, _shrink: usize) {} + + fn try_grow( + &self, + _reservation: &MemoryReservation, + _additional: usize, + ) -> Result<()> { + Ok(()) + } + + fn reclaim( + &self, + target_bytes: usize, + _exclude_consumer_id: Option<usize>, + ) -> Result<usize> { + self.reclaimed.fetch_add(target_bytes, Ordering::Relaxed); + Ok(target_bytes) + } + + fn reserved(&self) -> usize { + 0 + } + } + + #[test] + fn test_tracked_consumers_pool_reclaim_falls_back_to_inner_pool() { + let pool = TrackConsumersPool::new( + InnerReclaimOnlyPool::default(), + NonZeroUsize::new(3).unwrap(), + ); + + let reclaimed = pool.reclaim(64, None).unwrap(); + + assert_eq!(reclaimed, 64); + assert_eq!(pool.inner.reclaimed.load(Ordering::Relaxed), 64); + } }