Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions malariagen_data/anoph/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import sys
import warnings

import json
from functools import _lru_cache_wrapper
from contextlib import nullcontext
from datetime import date
from pathlib import Path
Expand Down Expand Up @@ -188,6 +190,181 @@ def __init__(
if results_cache is not None:
self._results_cache = Path(results_cache).expanduser().resolve()

# Mapping from category names to cache attribute name prefixes/patterns.
# Keys are the category names accepted by ``clear_cache`` and
# ``_iter_cache_attrs``; values are the instance attribute names that hold
# the cached data for that category. An attribute listed here may be an
# lru_cache wrapper, a dict, or a single cached object (see cache_info).
_CACHE_CATEGORIES: Dict[str, Tuple[str, ...]] = {
    # Core lookup tables: releases, sample sets, study and terms-of-use info.
    "base": (
        "_cache_releases",
        "_cache_available_releases",
        "_cache_sample_sets",
        "_cache_available_sample_sets",
        "_cache_sample_set_to_release",
        "_cache_sample_set_to_study",
        "_cache_sample_set_to_study_info",
        "_cache_sample_set_to_terms_of_use_info",
        "_cache_files",
    ),
    # Per-sample metadata tables and cohort definitions.
    "sample_metadata": (
        "_cache_sample_metadata",
        "_cache_cohorts",
        "_cache_cohort_geometries",
    ),
    "genome_features": ("_cache_genome_features",),
    "genome_sequence": ("_cache_genome",),
    # SNP sites, genotypes, site filters and annotations.
    "snp": (
        "_cache_snp_sites",
        "_cache_snp_genotypes",
        "_cache_site_filters",
        "_cache_site_annotations",
        "_cache_locate_site_class",
        # NOTE(review): "_cached_snp_calls" breaks the "_cache_" naming
        # pattern used by every other entry here — confirm it matches the
        # real attribute name, otherwise this cache is never found/cleared.
        "_cached_snp_calls",
    ),
    "haplotypes": (
        "_cache_haplotypes",
        "_cache_haplotype_sites",
    ),
    # Copy-number variation caches.
    "cnv": (
        "_cache_cnv_hmm",
        "_cache_cnv_coverage_calls",
        "_cache_cnv_discordant_read_calls",
    ),
    # Ancestry-informative marker variants.
    "aim": ("_cache_aim_variants",),
}

def _iter_cache_attrs(self, category="all"):
"""Yield (attr_name, obj) pairs for cache attributes on this instance.

Parameters
----------
category : str
A cache category name, or ``"all"`` to iterate over every cache.
"""
if category == "all":
# Gather all attribute names from every category.
attr_names = set()
for names in self._CACHE_CATEGORIES.values():
attr_names.update(names)
else:
if category not in self._CACHE_CATEGORIES:
valid = sorted(["all"] + list(self._CACHE_CATEGORIES.keys()))
raise ValueError(
f"Unknown cache category {category!r}. "
f"Valid options: {', '.join(repr(v) for v in valid)}"
)
attr_names = set(self._CACHE_CATEGORIES[category])

for attr_name in sorted(attr_names):
obj = getattr(self, attr_name, None)
if obj is not None:
yield attr_name, obj

@staticmethod
def _estimate_cache_entry_nbytes(value):
"""Best-effort deep size estimate for a single cached value."""
try:
import xarray as xr

if isinstance(value, (xr.Dataset, xr.DataArray)):
return value.nbytes, "xarray.nbytes"
except ImportError: # pragma: no cover
pass

try:
import numpy as np

if isinstance(value, np.ndarray):
return value.nbytes, "numpy.nbytes"
except ImportError: # pragma: no cover
pass

try:
import dask.array

if isinstance(value, dask.array.Array):
return value.nbytes, "dask upper bound"
except ImportError: # pragma: no cover
pass

if isinstance(value, bytes):
return len(value), "bytes length"

return sys.getsizeof(value), "sys.getsizeof shallow"

@doc(
    summary="""
    Return information about in-memory caches held by this instance.
    """,
    returns="""
    A dictionary keyed by cache attribute name. Each value is a
    dictionary with keys ``'entries'``, ``'nbytes'``, ``'kind'``,
    and ``'note'``.
    """,
)
def cache_info(self) -> Dict[str, Dict[str, Any]]:
    report: Dict[str, Dict[str, Any]] = {}

    for attr_name, cached in self._iter_cache_attrs("all"):
        if isinstance(cached, _lru_cache_wrapper):
            # lru_cache exposes entry counts but not entry sizes.
            report[attr_name] = {
                "entries": cached.cache_info().currsize,
                "nbytes": 0,
                "kind": "lru_cache",
                "note": "size not estimated for lru_cache",
            }
            continue

        if isinstance(cached, dict):
            # Best-effort size estimate summed over all values.
            estimates = [
                self._estimate_cache_entry_nbytes(v) for v in cached.values()
            ]
            notes = sorted({note for _, note in estimates})
            report[attr_name] = {
                "entries": len(cached),
                "nbytes": sum(nb for nb, _ in estimates),
                "kind": "dict",
                "note": ", ".join(notes) if notes else "empty",
            }
            continue

        # Single-value caches (e.g. zarr groups stored as Optional).
        nbytes, note = self._estimate_cache_entry_nbytes(cached)
        report[attr_name] = {
            "entries": 1,
            "nbytes": nbytes,
            "kind": "other",
            "note": note,
        }

    return report

@doc(
    summary="""
    Clear in-memory caches to free memory.
    """,
    extended_summary="""
    This is useful in long-running sessions (e.g., Jupyter notebooks
    or Google Colab) where cached data accumulates and causes memory
    pressure. Subsequent data access calls will repopulate the caches
    on demand.
    """,
    parameters=dict(
        category="""
        The cache category to clear. Use ``"all"`` (default) to clear
        every cache. Valid categories include ``"haplotypes"``,
        ``"cnv"``, ``"snp"``, ``"sample_metadata"``, ``"aim"``,
        ``"genome_features"``, ``"genome_sequence"``, and ``"base"``.
        """,
    ),
)
def clear_cache(self, category: str = "all") -> None:
    # Dispatch on the kind of cache holder found on the instance.
    for attr_name, cached in self._iter_cache_attrs(category):
        if isinstance(cached, _lru_cache_wrapper):
            cached.cache_clear()
            continue
        if isinstance(cached, dict):
            cached.clear()
            continue
        # Single-value caches — reset to None so they repopulate lazily.
        setattr(self, attr_name, None)

def _progress(self, iterable, desc=None, leave=False, **kwargs): # pragma: no cover
# Progress doesn't mix well with debug logging.
show_progress = self._show_progress and not self._debug
Expand Down
Loading
Loading