Skip to content

Commit 49631eb

Browse files
committed
feat: add public cache_info() and clear_cache() API to AnophelesBase
Add two public methods to AnophelesBase so that all resource classes (Ag3, Af1, Adir1, Amin1, Adar1, and Plasmodium resources) inherit a stable, documented way to inspect and release in-memory caches. This addresses silent memory pressure in long-running Jupyter/Colab sessions where cached haplotype, CNV, and SNP datasets accumulate without any public mechanism to reclaim the memory. - cache_info() returns a dict keyed by cache attribute name with entry count, estimated byte size, cache kind, and a note on the estimation method used (xarray.nbytes, numpy.nbytes, dask upper bound, bytes length, or sys.getsizeof shallow). - clear_cache(category="all") clears all or a specific category of caches. Supported categories: all, base, sample_metadata, genome_features, genome_sequence, snp, haplotypes, cnv, aim. Unknown categories raise ValueError listing valid options. Caches repopulate on demand after clearing, so calling clear_cache() is always safe mid-session. Closes #1289
1 parent 1da4a14 commit 49631eb

2 files changed

Lines changed: 399 additions & 0 deletions

File tree

malariagen_data/anoph/base.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
2+
import sys
23
import warnings
34

45
import json
6+
from functools import _lru_cache_wrapper
57
from contextlib import nullcontext
68
from datetime import date
79
from pathlib import Path
@@ -188,6 +190,181 @@ def __init__(
188190
if results_cache is not None:
189191
self._results_cache = Path(results_cache).expanduser().resolve()
190192

193+
# Mapping from category names to cache attribute name prefixes/patterns.
# Keys are the public category names accepted by ``clear_cache()`` and
# ``_iter_cache_attrs()``; values are the instance attribute names holding
# the corresponding caches. Attributes that are absent or ``None`` on a
# given instance are skipped at iteration time, so subclasses that lack a
# particular cache are handled transparently.
_CACHE_CATEGORIES: Dict[str, Tuple[str, ...]] = {
    # Release / sample-set bookkeeping caches.
    "base": (
        "_cache_releases",
        "_cache_available_releases",
        "_cache_sample_sets",
        "_cache_available_sample_sets",
        "_cache_sample_set_to_release",
        "_cache_sample_set_to_study",
        "_cache_sample_set_to_study_info",
        "_cache_sample_set_to_terms_of_use_info",
        "_cache_files",
    ),
    # Per-sample metadata and cohort caches.
    "sample_metadata": (
        "_cache_sample_metadata",
        "_cache_cohorts",
        "_cache_cohort_geometries",
    ),
    "genome_features": ("_cache_genome_features",),
    "genome_sequence": ("_cache_genome",),
    # SNP site/genotype/filter caches.
    "snp": (
        "_cache_snp_sites",
        "_cache_snp_genotypes",
        "_cache_site_filters",
        "_cache_site_annotations",
        "_cache_locate_site_class",
        "_cached_snp_calls",
    ),
    "haplotypes": (
        "_cache_haplotypes",
        "_cache_haplotype_sites",
    ),
    # Copy-number variation caches.
    "cnv": (
        "_cache_cnv_hmm",
        "_cache_cnv_coverage_calls",
        "_cache_cnv_discordant_read_calls",
    ),
    "aim": ("_cache_aim_variants",),
}
232+
233+
def _iter_cache_attrs(self, category="all"):
234+
"""Yield (attr_name, obj) pairs for cache attributes on this instance.
235+
236+
Parameters
237+
----------
238+
category : str
239+
A cache category name, or ``"all"`` to iterate over every cache.
240+
"""
241+
if category == "all":
242+
# Gather all attribute names from every category.
243+
attr_names = set()
244+
for names in self._CACHE_CATEGORIES.values():
245+
attr_names.update(names)
246+
else:
247+
if category not in self._CACHE_CATEGORIES:
248+
valid = sorted(["all"] + list(self._CACHE_CATEGORIES.keys()))
249+
raise ValueError(
250+
f"Unknown cache category {category!r}. "
251+
f"Valid options: {', '.join(repr(v) for v in valid)}"
252+
)
253+
attr_names = set(self._CACHE_CATEGORIES[category])
254+
255+
for attr_name in sorted(attr_names):
256+
obj = getattr(self, attr_name, None)
257+
if obj is not None:
258+
yield attr_name, obj
259+
260+
@staticmethod
261+
def _estimate_cache_entry_nbytes(value):
262+
"""Best-effort deep size estimate for a single cached value."""
263+
try:
264+
import xarray as xr
265+
266+
if isinstance(value, (xr.Dataset, xr.DataArray)):
267+
return value.nbytes, "xarray.nbytes"
268+
except ImportError: # pragma: no cover
269+
pass
270+
271+
try:
272+
import numpy as np
273+
274+
if isinstance(value, np.ndarray):
275+
return value.nbytes, "numpy.nbytes"
276+
except ImportError: # pragma: no cover
277+
pass
278+
279+
try:
280+
import dask.array
281+
282+
if isinstance(value, dask.array.Array):
283+
return value.nbytes, "dask upper bound"
284+
except ImportError: # pragma: no cover
285+
pass
286+
287+
if isinstance(value, bytes):
288+
return len(value), "bytes length"
289+
290+
return sys.getsizeof(value), "sys.getsizeof shallow"
291+
292+
@doc(
    summary="""
        Return information about in-memory caches held by this instance.
    """,
    returns="""
        A dictionary keyed by cache attribute name. Each value is a
        dictionary with keys ``'entries'``, ``'nbytes'``, ``'kind'``,
        and ``'note'``.
    """,
)
def cache_info(self) -> Dict[str, Dict[str, Any]]:
    info: Dict[str, Dict[str, Any]] = {}

    for attr_name, obj in self._iter_cache_attrs("all"):
        # Duck-type lru_cache wrappers rather than isinstance() against
        # functools._lru_cache_wrapper: that class is a private
        # implementation detail, and attribute access on an lru-cached
        # *method* returns a bound-method proxy that is NOT an instance
        # of the wrapper class — but the proxy does forward cache_info().
        if callable(obj) and hasattr(obj, "cache_info") and hasattr(obj, "cache_clear"):
            ci = obj.cache_info()
            info[attr_name] = {
                # NOTE: lru_cache on a method counts entries across ALL
                # instances of the class, not just this one.
                "entries": ci.currsize,
                "nbytes": 0,
                "kind": "lru_cache",
                "note": "size not estimated for lru_cache",
            }
        elif isinstance(obj, dict):
            # Keyed caches: sum a best-effort per-entry size estimate
            # and collect the set of estimation methods used.
            total_nbytes = 0
            notes = set()
            for v in obj.values():
                nb, note = self._estimate_cache_entry_nbytes(v)
                total_nbytes += nb
                notes.add(note)
            info[attr_name] = {
                "entries": len(obj),
                "nbytes": total_nbytes,
                "kind": "dict",
                "note": ", ".join(sorted(notes)) if notes else "empty",
            }
        else:
            # Single-value caches (e.g. zarr groups stored as Optional).
            nb, note = self._estimate_cache_entry_nbytes(obj)
            info[attr_name] = {
                "entries": 1,
                "nbytes": nb,
                "kind": "other",
                "note": note,
            }

    return info
338+
339+
@doc(
    summary="""
        Clear in-memory caches to free memory.
    """,
    extended_summary="""
        This is useful in long-running sessions (e.g., Jupyter notebooks
        or Google Colab) where cached data accumulates and causes memory
        pressure. Subsequent data access calls will repopulate the caches
        on demand.
    """,
    parameters=dict(
        category="""
            The cache category to clear. Use ``"all"`` (default) to clear
            every cache. Valid categories include ``"haplotypes"``,
            ``"cnv"``, ``"snp"``, ``"sample_metadata"``, ``"aim"``,
            ``"genome_features"``, ``"genome_sequence"``, and ``"base"``.
        """,
    ),
)
def clear_cache(self, category: str = "all") -> None:
    # An unknown category raises ValueError inside _iter_cache_attrs,
    # before anything has been cleared.
    for attr_name, obj in self._iter_cache_attrs(category):
        # Duck-type lru_cache wrappers instead of isinstance() against
        # the private functools._lru_cache_wrapper class: the private
        # name is an implementation detail, and attribute access on an
        # lru-cached *method* yields a bound-method proxy that is not an
        # instance of that class but does forward cache_clear().
        if callable(obj) and hasattr(obj, "cache_clear"):
            # NOTE: for lru_cache-wrapped methods this clears entries
            # for every instance of the class, not just ``self``.
            obj.cache_clear()
        elif isinstance(obj, dict):
            obj.clear()
        else:
            # Single-value caches — reset to None; repopulated on demand.
            setattr(self, attr_name, None)
367+
191368
def _progress(self, iterable, desc=None, leave=False, **kwargs): # pragma: no cover
192369
# Progress doesn't mix well with debug logging.
193370
show_progress = self._show_progress and not self._debug

0 commit comments

Comments
 (0)