diff --git a/crates/pet-conda/src/lib.rs b/crates/pet-conda/src/lib.rs index 8c3915bd..1ab3021d 100644 --- a/crates/pet-conda/src/lib.rs +++ b/crates/pet-conda/src/lib.rs @@ -391,9 +391,9 @@ impl Locator for Conda { reporter.report_environment(&env); // Also check for a mamba/micromamba manager in the same directory and report it. - // Reporting inside the closure minimizes the TOCTOU window compared to a - // separate contains_key check, though concurrent threads may still - // briefly both invoke the closure before the write-lock double-check. + // LocatorCache coalesces concurrent lookups for this conda_dir, so mamba + // discovery and its reporting side effect run at most once per in-flight + // key. let _ = self.mamba_managers.get_or_insert_with(conda_dir.clone(), || { let mgr = get_mamba_manager(conda_dir); if let Some(ref m) = mgr { diff --git a/crates/pet-core/src/cache.rs b/crates/pet-core/src/cache.rs index 6cdcacb2..8477e0d7 100644 --- a/crates/pet-core/src/cache.rs +++ b/crates/pet-core/src/cache.rs @@ -6,7 +6,12 @@ //! Provides a thread-safe cache wrapper that consolidates common caching patterns //! used across multiple locators in the codebase. -use std::{collections::HashMap, hash::Hash, path::PathBuf, sync::RwLock}; +use std::{ + collections::HashMap, + hash::Hash, + path::PathBuf, + sync::{Arc, Condvar, Mutex, RwLock}, +}; use crate::{manager::EnvManager, python_environment::PythonEnvironment}; @@ -17,6 +22,77 @@ use crate::{manager::EnvManager, python_environment::PythonEnvironment}; /// returned from the cache. pub struct LocatorCache { cache: RwLock>, + in_flight: Mutex>>>, +} + +struct InFlightEntry { + result: Mutex>>, + changed: Condvar, +} + +#[derive(Clone)] +enum InFlightResult { + Value(Option), + Panicked, +} + +struct InFlightOwnerGuard<'a, K: Eq + Hash, V> { + key: Option, + entry: Arc>, + in_flight: &'a Mutex>>>, +} + +enum InFlightClaim<'a, K: Eq + Hash, V> { + Owner(InFlightOwnerGuard<'a, K, V>), + Waiter(Arc>), +} + +impl InFlightEntry { + fn new() -> Self { + Self { + result: Mutex::new(None), + changed: Condvar::new(), + } + } +} + +impl InFlightOwnerGuard<'_, K, V> { + fn complete(mut self, result: Option) { + self.publish_result(result); + } + + fn publish_result(&mut self, result: Option) { + self.publish(InFlightResult::Value(result)); + } + + fn publish_panic(&mut self) { + self.publish(InFlightResult::Panicked); + } + + fn publish(&mut self, result: InFlightResult) { + *self + .entry + .result + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) = Some(result); + + if let Some(key) = self.key.take() { + self.in_flight + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .remove(&key); + } + + self.entry.changed.notify_all(); + } +} + +impl Drop for InFlightOwnerGuard<'_, K, V> { + fn drop(&mut self) { + if self.key.is_some() { + self.publish_panic(); + } + } } impl LocatorCache { @@ -24,6 +100,7 @@ impl LocatorCache { pub fn new() -> Self { Self { cache: RwLock::new(HashMap::new()), + in_flight: Mutex::new(HashMap::new()), } } @@ -68,15 +145,24 @@ impl LocatorCache { /// Returns a cloned value for the given key if it exists, otherwise computes /// and inserts the value using the provided closure. /// - /// This method first checks with a read lock, then upgrades to a write lock - /// if the value needs to be computed and inserted. + /// This method first checks with a read lock. If the key is missing, it + /// claims a per-key in-flight slot before computing the value so concurrent + /// callers for the same key wait for the first computation instead of + /// running duplicate closures with duplicate side effects. `None` results + /// are shared with current waiters but are not stored in the cache, so later + /// calls can retry the computation. + /// + /// The closure must not call `get_or_insert_with` for the same cache and key, + /// directly or indirectly, because the owner would wait on its own in-flight + /// entry. If the owner panics before publishing a result, waiters for the same + /// key are woken and panic instead of silently receiving an uncached `None`. #[must_use] pub fn get_or_insert_with(&self, key: K, f: F) -> Option where F: FnOnce() -> Option, K: Clone, { - // First check with read lock + // First check with read lock. { let cache = self.cache.read().expect("locator cache lock poisoned"); if let Some(value) = cache.get(&key) { @@ -84,18 +170,80 @@ impl LocatorCache { } } + let in_flight = match self.claim_in_flight(&key) { + InFlightClaim::Owner(in_flight) => in_flight, + InFlightClaim::Waiter(entry) => return Self::wait_for_in_flight(entry), + }; + + // Check again after claiming the in-flight slot. Another thread may have + // completed the same key while this thread was waiting. + { + let cache = self.cache.read().expect("locator cache lock poisoned"); + if let Some(value) = cache.get(&key) { + let result = Some(value.clone()); + in_flight.complete(result.clone()); + return result; + } + } + // Compute the value (outside of any lock) - if let Some(value) = f() { + let result = if let Some(value) = f() { // Acquire write lock and insert let mut cache = self.cache.write().expect("locator cache lock poisoned"); // Double-check in case another thread inserted while we were computing if let Some(existing) = cache.get(&key) { - return Some(existing.clone()); + Some(existing.clone()) + } else { + cache.insert(key, value.clone()); + Some(value) } - cache.insert(key, value.clone()); - Some(value) } else { None + }; + + in_flight.complete(result.clone()); + result + } + + fn claim_in_flight(&self, key: &K) -> InFlightClaim<'_, K, V> + where + K: Clone, + { + let mut in_flight = self + .in_flight + .lock() + .expect("locator cache in-flight lock poisoned"); + + if let Some(entry) = in_flight.get(key) { + return InFlightClaim::Waiter(entry.clone()); + } + + let entry = Arc::new(InFlightEntry::new()); + in_flight.insert(key.clone(), entry.clone()); + InFlightClaim::Owner(InFlightOwnerGuard { + key: Some(key.clone()), + entry, + in_flight: &self.in_flight, + }) + } + + fn wait_for_in_flight(entry: Arc>) -> Option { + let mut result = entry + .result + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + while result.is_none() { + result = entry + .changed + .wait(result) + .unwrap_or_else(|poisoned| poisoned.into_inner()); + } + + match result.clone().unwrap() { + InFlightResult::Value(value) => value, + InFlightResult::Panicked => { + panic!("locator cache in-flight owner panicked before publishing a result"); + } } } @@ -160,6 +308,12 @@ pub type ManagerCache = LocatorCache; #[cfg(test)] mod tests { use super::*; + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + mpsc, Arc, Barrier, Mutex, + }; + use std::thread; + use std::time::Duration; #[test] fn test_cache_get_and_insert() { @@ -192,6 +346,164 @@ mod tests { assert!(!cache.contains_key(&"key2".to_string())); } + #[test] + fn test_cache_get_or_insert_with_runs_one_closure_per_key() { + let cache: Arc> = Arc::new(LocatorCache::new()); + let barrier = Arc::new(Barrier::new(3)); + let calls = Arc::new(AtomicUsize::new(0)); + let (started_tx, started_rx) = mpsc::channel(); + let (release_tx, release_rx) = mpsc::channel(); + let release_rx = Arc::new(Mutex::new(release_rx)); + let mut handles = vec![]; + + for _ in 0..2 { + let cache = cache.clone(); + let barrier = barrier.clone(); + let calls = calls.clone(); + let started_tx = started_tx.clone(); + let release_rx = release_rx.clone(); + handles.push(thread::spawn(move || { + barrier.wait(); + cache.get_or_insert_with("key".to_string(), || { + calls.fetch_add(1, Ordering::SeqCst); + started_tx.send(()).unwrap(); + release_rx + .lock() + .unwrap() + .recv_timeout(Duration::from_secs(5)) + .unwrap(); + Some(42) + }) + })); + } + + barrier.wait(); + started_rx.recv_timeout(Duration::from_secs(5)).unwrap(); + assert_eq!(calls.load(Ordering::SeqCst), 1); + assert!(started_rx.try_recv().is_err()); + + release_tx.send(()).unwrap(); + release_tx.send(()).unwrap(); + + let mut results = handles + .into_iter() + .map(|handle| handle.join().unwrap()) + .collect::>(); + results.sort(); + + assert_eq!(results, vec![Some(42), Some(42)]); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + #[test] + fn test_cache_get_or_insert_with_shares_concurrent_none_result() { + let entry = Arc::new(InFlightEntry::new()); + let waiter_entry = entry.clone(); + let waiter = + thread::spawn(move || LocatorCache::::wait_for_in_flight(waiter_entry)); + + *entry + .result + .lock() + .expect("locator cache in-flight result lock poisoned") = + Some(InFlightResult::Value(None)); + entry.changed.notify_all(); + + assert_eq!(waiter.join().unwrap(), None); + + let cache: LocatorCache = LocatorCache::new(); + assert_eq!(cache.get_or_insert_with("key".to_string(), || None), None); + assert!(!cache.contains_key(&"key".to_string())); + + assert_eq!( + cache.get_or_insert_with("key".to_string(), || Some(42)), + Some(42) + ); + } + + #[test] + fn test_cache_get_or_insert_with_panic_releases_in_flight_key() { + let cache: LocatorCache = LocatorCache::new(); + + let result = std::panic::catch_unwind(|| { + let _ = cache.get_or_insert_with("key".to_string(), || -> Option { + panic!("boom"); + }); + }); + + assert!(result.is_err()); + assert!(!cache.contains_key(&"key".to_string())); + assert_eq!( + cache.get_or_insert_with("key".to_string(), || Some(42)), + Some(42) + ); + } + + #[test] + fn test_cache_panicked_owner_wakes_waiters_with_panic() { + let key = "key".to_string(); + let entry = Arc::new(InFlightEntry::new()); + let in_flight: Mutex>>> = Mutex::new(HashMap::new()); + + in_flight.lock().unwrap().insert(key.clone(), entry.clone()); + let owner = InFlightOwnerGuard { + key: Some(key), + entry: entry.clone(), + in_flight: &in_flight, + }; + + drop(owner); + + let waiter_result = + std::panic::catch_unwind(|| LocatorCache::::wait_for_in_flight(entry)); + + assert!(waiter_result.is_err()); + assert!(in_flight.lock().unwrap().is_empty()); + } + + #[test] + fn test_cache_publish_result_recovers_poisoned_in_flight_locks() { + let key = "key".to_string(); + let entry = Arc::new(InFlightEntry::new()); + let in_flight: Arc>>>> = + Arc::new(Mutex::new(HashMap::new())); + + in_flight.lock().unwrap().insert(key.clone(), entry.clone()); + + let poison_entry = entry.clone(); + assert!(thread::spawn(move || { + let _guard = poison_entry.result.lock().unwrap(); + panic!("poison result lock"); + }) + .join() + .is_err()); + + let poison_in_flight = in_flight.clone(); + assert!(thread::spawn(move || { + let _guard = poison_in_flight.lock().unwrap(); + panic!("poison in-flight lock"); + }) + .join() + .is_err()); + + let owner = InFlightOwnerGuard { + key: Some(key), + entry: entry.clone(), + in_flight: &in_flight, + }; + + owner.complete(Some(42)); + + assert_eq!( + LocatorCache::::wait_for_in_flight(entry), + Some(42) + ); + assert!(in_flight + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .is_empty()); + } + #[test] fn test_cache_clear() { let cache: LocatorCache = LocatorCache::new();