Skip to content

Commit 7a4a409

Browse files
committed
feat: add results caching to cohort_diversity_stats
1 parent 27ac08c commit 7a4a409

2 files changed

Lines changed: 158 additions & 13 deletions

File tree

malariagen_data/anopheles.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,10 @@ def cohort_diversity_stats(
10481048
) -> pd.Series:
10491049
debug = self._log.debug
10501050

1051+
# Change this name if you ever change the behaviour of this function, to
1052+
# invalidate any previously cached data.
1053+
name = "cohort_diversity_stats_v1"
1054+
10511055
debug("process cohort parameter")
10521056
cohort_query = None
10531057
if isinstance(cohort, str):
@@ -1068,28 +1072,58 @@ def cohort_diversity_stats(
10681072
else:
10691073
raise TypeError(r"invalid cohort parameter: {cohort!r}")
10701074

1071-
debug("access allele counts")
1072-
ac = self.snp_allele_counts(
1075+
params = dict(
1076+
cohort_label=cohort_label,
1077+
cohort_query=cohort_query,
1078+
cohort_size=cohort_size,
10731079
region=region,
1080+
min_cohort_size=min_cohort_size,
1081+
max_cohort_size=max_cohort_size,
10741082
site_mask=site_mask,
10751083
site_class=site_class,
1076-
sample_query=cohort_query,
10771084
sample_sets=sample_sets,
1078-
cohort_size=cohort_size,
1079-
min_cohort_size=min_cohort_size,
1080-
max_cohort_size=max_cohort_size,
10811085
random_seed=random_seed,
1086+
n_jack=n_jack,
1087+
confidence_level=confidence_level,
10821088
chunks=chunks,
10831089
inline_array=inline_array,
10841090
)
10851091

1086-
debug("compute diversity stats")
1087-
stats = self._block_jackknife_cohort_diversity_stats(
1088-
cohort_label=cohort_label,
1089-
ac=ac,
1090-
n_jack=n_jack,
1091-
confidence_level=confidence_level,
1092-
)
1092+
# Try to retrieve results from the cache.
1093+
try:
1094+
results = self.results_cache_get(name=name, params=params)
1095+
stats = {
1096+
key: value.item()
1097+
if isinstance(value, np.ndarray) and value.shape == ()
1098+
else value
1099+
for key, value in results.items()
1100+
}
1101+
1102+
except CacheMiss:
1103+
debug("access allele counts")
1104+
ac = self.snp_allele_counts(
1105+
region=region,
1106+
site_mask=site_mask,
1107+
site_class=site_class,
1108+
sample_query=cohort_query,
1109+
sample_sets=sample_sets,
1110+
cohort_size=cohort_size,
1111+
min_cohort_size=min_cohort_size,
1112+
max_cohort_size=max_cohort_size,
1113+
random_seed=random_seed,
1114+
chunks=chunks,
1115+
inline_array=inline_array,
1116+
)
1117+
1118+
debug("compute diversity stats")
1119+
stats = self._block_jackknife_cohort_diversity_stats(
1120+
cohort_label=cohort_label,
1121+
ac=ac,
1122+
n_jack=n_jack,
1123+
confidence_level=confidence_level,
1124+
)
1125+
1126+
self.results_cache_set(name=name, params=params, results=stats)
10931127

10941128
debug("compute some extra cohort variables")
10951129
df_samples = self.sample_metadata(

tests/test_anopheles.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from types import SimpleNamespace
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
from malariagen_data.anopheles import AnophelesDataResource
7+
from malariagen_data.util import CacheMiss
8+
9+
10+
class DummyAnophelesDataResource(AnophelesDataResource):
11+
@property
12+
def _xpehh_gwss_cache_name(self):
13+
return "dummy_xpehh_gwss_cache"
14+
15+
@property
16+
def _ihs_gwss_cache_name(self):
17+
return "dummy_ihs_gwss_cache"
18+
19+
@property
20+
def _roh_hmm_cache_name(self):
21+
return "dummy_roh_hmm_cache"
22+
23+
def __init__(self):
24+
self._log = SimpleNamespace(debug=lambda *args, **kwargs: None)
25+
self._cache = {}
26+
self.block_jackknife_calls = 0
27+
28+
def sample_metadata(self, sample_sets=None, sample_query=None): # noqa: ARG002
29+
data = pd.DataFrame(
30+
{
31+
"sample_id": ["s1", "s2", "s3"],
32+
"cohort_admin1_year": ["cohort_a", "cohort_a", "cohort_b"],
33+
"taxon": ["gambiae", "gambiae", "coluzzii"],
34+
"year": [2020, 2020, 2021],
35+
"month": [1, 2, 3],
36+
"country": ["Ghana", "Ghana", "Benin"],
37+
"admin1_iso": ["GH-AA", "GH-AA", "BJ-AK"],
38+
"admin1_name": ["Accra", "Accra", "Atlantique"],
39+
"admin2_name": ["Accra", "Accra", "Abomey-Calavi"],
40+
"longitude": [-0.2, -0.4, 2.3],
41+
"latitude": [5.6, 5.8, 6.4],
42+
}
43+
)
44+
if sample_query is not None:
45+
return data.query(sample_query).reset_index(drop=True)
46+
return data
47+
48+
def snp_allele_counts(self, **kwargs): # noqa: ARG002
49+
return np.array([[2, 0], [1, 1], [0, 2]])
50+
51+
def _block_jackknife_cohort_diversity_stats(
52+
self, *, cohort_label, ac, n_jack, confidence_level # noqa: ARG002
53+
):
54+
self.block_jackknife_calls += 1
55+
return {
56+
"cohort": cohort_label,
57+
"theta_pi": 0.123,
58+
"theta_pi_estimate": 0.124,
59+
"theta_pi_bias": 0.001,
60+
"theta_pi_std_err": 0.01,
61+
"theta_pi_ci_err": 0.02,
62+
"theta_pi_ci_low": 0.1,
63+
"theta_pi_ci_upp": 0.14,
64+
"theta_w": 0.111,
65+
"theta_w_estimate": 0.112,
66+
"theta_w_bias": 0.001,
67+
"theta_w_std_err": 0.01,
68+
"theta_w_ci_err": 0.02,
69+
"theta_w_ci_low": 0.09,
70+
"theta_w_ci_upp": 0.13,
71+
"tajima_d": 0.3,
72+
"tajima_d_estimate": 0.31,
73+
"tajima_d_bias": 0.01,
74+
"tajima_d_std_err": 0.05,
75+
"tajima_d_ci_err": 0.1,
76+
"tajima_d_ci_low": 0.2,
77+
"tajima_d_ci_upp": 0.4,
78+
}
79+
80+
def results_cache_get(self, *, name, params):
81+
key = (name, repr(params))
82+
if key not in self._cache:
83+
raise CacheMiss
84+
return self._cache[key]
85+
86+
def results_cache_set(self, *, name, params, results):
87+
key = (name, repr(params))
88+
self._cache[key] = results
89+
90+
91+
def test_cohort_diversity_stats_uses_cache():
92+
api = DummyAnophelesDataResource()
93+
94+
stats1 = api.cohort_diversity_stats(
95+
cohort="cohort_a",
96+
cohort_size=2,
97+
region="2L",
98+
n_jack=10,
99+
confidence_level=0.95,
100+
)
101+
stats2 = api.cohort_diversity_stats(
102+
cohort="cohort_a",
103+
cohort_size=2,
104+
region="2L",
105+
n_jack=10,
106+
confidence_level=0.95,
107+
)
108+
109+
assert api.block_jackknife_calls == 1
110+
pd.testing.assert_series_equal(stats1.sort_index(), stats2.sort_index())
111+

0 commit comments

Comments
 (0)