|
| 1 | +import warnings |
1 | 2 | from functools import lru_cache |
2 | 3 | from typing import Any, Dict, List, Optional, Tuple, Union |
3 | 4 |
|
|
38 | 39 | from .genome_sequence import AnophelesGenomeSequenceData |
39 | 40 | from .sample_metadata import AnophelesSampleMetadata |
40 | 41 |
|
| 42 | +# Maximum number of entries kept in the per-instance _cache_locate_site_class |
| 43 | +# dict. The natural ceiling is n_contigs × n_site_classes (≈ 45 for Ag3), so |
| 44 | +# 64 gives comfortable headroom without allowing unbounded growth. |
| 45 | +_LOCATE_SITE_CLASS_CACHE_MAXSIZE = 64 |
| 46 | + |
41 | 47 |
|
42 | 48 | class AnophelesSnpData( |
43 | 49 | AnophelesSampleMetadata, AnophelesGenomeFeaturesData, AnophelesGenomeSequenceData |
@@ -68,6 +74,14 @@ def __init__( |
68 | 74 | self._cache_site_annotations = None |
69 | 75 | self._cache_locate_site_class: Dict = dict() |
70 | 76 |
|
| 77 | + # Create the SNP-calls cache as a per-instance lru_cache wrapping the |
| 78 | + # bound method. Storing it on the instance (rather than using a |
| 79 | + # class-level @lru_cache decorator) means: |
| 80 | + # 1. `self` is not part of the cache key, so old instances are freed |
| 81 | + # normally when the caller drops their reference. |
| 82 | + # 2. Different instances have independent, non-interfering caches. |
| 83 | + self._cached_snp_calls = lru_cache(maxsize=2)(self._raw_snp_calls) |
| 84 | + |
71 | 85 | @property |
72 | 86 | def _site_filters_analysis(self) -> Optional[str]: |
73 | 87 | if self._site_filters_analysis_override: |
@@ -928,6 +942,13 @@ def _locate_site_class( |
928 | 942 |
|
929 | 943 | self._cache_locate_site_class[cache_key] = loc_ann |
930 | 944 |
|
| 945 | + # Evict the oldest entry when the cache exceeds its size limit. |
| 946 | + # Plain dicts preserve insertion order (Python 3.7+), so the first |
| 947 | + # key is always the oldest. |
| 948 | + while len(self._cache_locate_site_class) > _LOCATE_SITE_CLASS_CACHE_MAXSIZE: |
| 949 | + oldest = next(iter(self._cache_locate_site_class)) |
| 950 | + del self._cache_locate_site_class[oldest] |
| 951 | + |
931 | 952 | return loc_ann |
932 | 953 |
|
933 | 954 | def _snp_calls_for_contig( |
@@ -1088,16 +1109,7 @@ def snp_calls( |
1088 | 1109 | chunks=chunks, |
1089 | 1110 | ) |
1090 | 1111 |
|
1091 | | - # Here we cache to improve performance for functions which |
1092 | | - # access SNP calls more than once. For example, this currently |
1093 | | - # happens during access of biallelic SNP calls, because a |
1094 | | - # first computation of allele counts is required, before |
1095 | | - # then using that to filter SNP calls. |
1096 | | - # |
1097 | | - # We only cache up to 2 items because otherwise we can see |
1098 | | - # high memory usage. |
1099 | | - @lru_cache(maxsize=2) |
1100 | | - def _cached_snp_calls( |
| 1112 | + def _raw_snp_calls( |
1101 | 1113 | self, |
1102 | 1114 | *, |
1103 | 1115 | regions: Tuple[Region, ...], |
@@ -1253,6 +1265,12 @@ def _snp_calls( |
1253 | 1265 | if max_cohort_size is not None: |
1254 | 1266 | n_samples = ds.sizes["samples"] |
1255 | 1267 | if n_samples > max_cohort_size: |
| 1268 | + warnings.warn( |
| 1269 | + f"Cohort downsampled from {n_samples} to {max_cohort_size} " |
| 1270 | + "samples. Set max_cohort_size=None to disable downsampling.", |
| 1271 | + UserWarning, |
| 1272 | + stacklevel=2, |
| 1273 | + ) |
1256 | 1274 | rng = np.random.default_rng(seed=random_seed) |
1257 | 1275 | loc_downsample = rng.choice( |
1258 | 1276 | n_samples, size=max_cohort_size, replace=False |
|
0 commit comments