|
1 | 1 | import os |
| 2 | +import sys |
2 | 3 | import warnings |
3 | 4 |
|
4 | 5 | import json |
| 6 | +from functools import _lru_cache_wrapper |
5 | 7 | from contextlib import nullcontext |
6 | 8 | from datetime import date |
7 | 9 | from pathlib import Path |
@@ -188,6 +190,181 @@ def __init__( |
188 | 190 | if results_cache is not None: |
189 | 191 | self._results_cache = Path(results_cache).expanduser().resolve() |
190 | 192 |
|
    # Mapping from category names to cache attribute name prefixes/patterns.
    # Keys are the user-facing category names accepted by
    # ``_iter_cache_attrs()``/``clear_cache()``; values are the instance
    # attribute names that hold the corresponding in-memory caches.
    _CACHE_CATEGORIES: Dict[str, Tuple[str, ...]] = {
        # Release / sample-set discovery and file listings.
        "base": (
            "_cache_releases",
            "_cache_available_releases",
            "_cache_sample_sets",
            "_cache_available_sample_sets",
            "_cache_sample_set_to_release",
            "_cache_sample_set_to_study",
            "_cache_sample_set_to_study_info",
            "_cache_sample_set_to_terms_of_use_info",
            "_cache_files",
        ),
        "sample_metadata": (
            "_cache_sample_metadata",
            "_cache_cohorts",
            "_cache_cohort_geometries",
        ),
        "genome_features": ("_cache_genome_features",),
        "genome_sequence": ("_cache_genome",),
        "snp": (
            "_cache_snp_sites",
            "_cache_snp_genotypes",
            "_cache_site_filters",
            "_cache_site_annotations",
            "_cache_locate_site_class",
            # NOTE(review): "_cached_snp_calls" breaks the "_cache_" naming
            # pattern used by every other entry — confirm it matches the
            # actual attribute name on the instance.
            "_cached_snp_calls",
        ),
        "haplotypes": (
            "_cache_haplotypes",
            "_cache_haplotype_sites",
        ),
        "cnv": (
            "_cache_cnv_hmm",
            "_cache_cnv_coverage_calls",
            "_cache_cnv_discordant_read_calls",
        ),
        "aim": ("_cache_aim_variants",),
    }
| 232 | + |
| 233 | + def _iter_cache_attrs(self, category="all"): |
| 234 | + """Yield (attr_name, obj) pairs for cache attributes on this instance. |
| 235 | +
|
| 236 | + Parameters |
| 237 | + ---------- |
| 238 | + category : str |
| 239 | + A cache category name, or ``"all"`` to iterate over every cache. |
| 240 | + """ |
| 241 | + if category == "all": |
| 242 | + # Gather all attribute names from every category. |
| 243 | + attr_names = set() |
| 244 | + for names in self._CACHE_CATEGORIES.values(): |
| 245 | + attr_names.update(names) |
| 246 | + else: |
| 247 | + if category not in self._CACHE_CATEGORIES: |
| 248 | + valid = sorted(["all"] + list(self._CACHE_CATEGORIES.keys())) |
| 249 | + raise ValueError( |
| 250 | + f"Unknown cache category {category!r}. " |
| 251 | + f"Valid options: {', '.join(repr(v) for v in valid)}" |
| 252 | + ) |
| 253 | + attr_names = set(self._CACHE_CATEGORIES[category]) |
| 254 | + |
| 255 | + for attr_name in sorted(attr_names): |
| 256 | + obj = getattr(self, attr_name, None) |
| 257 | + if obj is not None: |
| 258 | + yield attr_name, obj |
| 259 | + |
| 260 | + @staticmethod |
| 261 | + def _estimate_cache_entry_nbytes(value): |
| 262 | + """Best-effort deep size estimate for a single cached value.""" |
| 263 | + try: |
| 264 | + import xarray as xr |
| 265 | + |
| 266 | + if isinstance(value, (xr.Dataset, xr.DataArray)): |
| 267 | + return value.nbytes, "xarray.nbytes" |
| 268 | + except ImportError: # pragma: no cover |
| 269 | + pass |
| 270 | + |
| 271 | + try: |
| 272 | + import numpy as np |
| 273 | + |
| 274 | + if isinstance(value, np.ndarray): |
| 275 | + return value.nbytes, "numpy.nbytes" |
| 276 | + except ImportError: # pragma: no cover |
| 277 | + pass |
| 278 | + |
| 279 | + try: |
| 280 | + import dask.array |
| 281 | + |
| 282 | + if isinstance(value, dask.array.Array): |
| 283 | + return value.nbytes, "dask upper bound" |
| 284 | + except ImportError: # pragma: no cover |
| 285 | + pass |
| 286 | + |
| 287 | + if isinstance(value, bytes): |
| 288 | + return len(value), "bytes length" |
| 289 | + |
| 290 | + return sys.getsizeof(value), "sys.getsizeof shallow" |
| 291 | + |
| 292 | + @doc( |
| 293 | + summary=""" |
| 294 | + Return information about in-memory caches held by this instance. |
| 295 | + """, |
| 296 | + returns=""" |
| 297 | + A dictionary keyed by cache attribute name. Each value is a |
| 298 | + dictionary with keys ``'entries'``, ``'nbytes'``, ``'kind'``, |
| 299 | + and ``'note'``. |
| 300 | + """, |
| 301 | + ) |
| 302 | + def cache_info(self) -> Dict[str, Dict[str, Any]]: |
| 303 | + info: Dict[str, Dict[str, Any]] = {} |
| 304 | + |
| 305 | + for attr_name, obj in self._iter_cache_attrs("all"): |
| 306 | + if isinstance(obj, _lru_cache_wrapper): |
| 307 | + ci = obj.cache_info() |
| 308 | + info[attr_name] = { |
| 309 | + "entries": ci.currsize, |
| 310 | + "nbytes": 0, |
| 311 | + "kind": "lru_cache", |
| 312 | + "note": "size not estimated for lru_cache", |
| 313 | + } |
| 314 | + elif isinstance(obj, dict): |
| 315 | + total_nbytes = 0 |
| 316 | + notes = set() |
| 317 | + for v in obj.values(): |
| 318 | + nb, note = self._estimate_cache_entry_nbytes(v) |
| 319 | + total_nbytes += nb |
| 320 | + notes.add(note) |
| 321 | + info[attr_name] = { |
| 322 | + "entries": len(obj), |
| 323 | + "nbytes": total_nbytes, |
| 324 | + "kind": "dict", |
| 325 | + "note": ", ".join(sorted(notes)) if notes else "empty", |
| 326 | + } |
| 327 | + else: |
| 328 | + # Single-value caches (e.g. zarr groups stored as Optional). |
| 329 | + nb, note = self._estimate_cache_entry_nbytes(obj) |
| 330 | + info[attr_name] = { |
| 331 | + "entries": 1, |
| 332 | + "nbytes": nb, |
| 333 | + "kind": "other", |
| 334 | + "note": note, |
| 335 | + } |
| 336 | + |
| 337 | + return info |
| 338 | + |
| 339 | + @doc( |
| 340 | + summary=""" |
| 341 | + Clear in-memory caches to free memory. |
| 342 | + """, |
| 343 | + extended_summary=""" |
| 344 | + This is useful in long-running sessions (e.g., Jupyter notebooks |
| 345 | + or Google Colab) where cached data accumulates and causes memory |
| 346 | + pressure. Subsequent data access calls will repopulate the caches |
| 347 | + on demand. |
| 348 | + """, |
| 349 | + parameters=dict( |
| 350 | + category=""" |
| 351 | + The cache category to clear. Use ``"all"`` (default) to clear |
| 352 | + every cache. Valid categories include ``"haplotypes"``, |
| 353 | + ``"cnv"``, ``"snp"``, ``"sample_metadata"``, ``"aim"``, |
| 354 | + ``"genome_features"``, ``"genome_sequence"``, and ``"base"``. |
| 355 | + """, |
| 356 | + ), |
| 357 | + ) |
| 358 | + def clear_cache(self, category: str = "all") -> None: |
| 359 | + for attr_name, obj in self._iter_cache_attrs(category): |
| 360 | + if isinstance(obj, _lru_cache_wrapper): |
| 361 | + obj.cache_clear() |
| 362 | + elif isinstance(obj, dict): |
| 363 | + obj.clear() |
| 364 | + else: |
| 365 | + # Single-value caches — reset to None. |
| 366 | + setattr(self, attr_name, None) |
| 367 | + |
191 | 368 | def _progress(self, iterable, desc=None, leave=False, **kwargs): # pragma: no cover |
192 | 369 | # Progress doesn't mix well with debug logging. |
193 | 370 | show_progress = self._show_progress and not self._debug |
|
0 commit comments