Skip to content

Commit ea9e24c

Browse files
authored
Merge pull request #1014 from Sharon-codes/issue-798-cohort-diversity-cache
feat: add results caching to cohort_diversity_stats
2 parents 17881d1 + 4b3ae3b commit ea9e24c

2 files changed

Lines changed: 106 additions & 13 deletions

File tree

malariagen_data/anopheles.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,10 @@ def cohort_diversity_stats(
371371
) -> pd.Series:
372372
debug = self._log.debug
373373

374+
# Change this name if you ever change the behaviour of this function, to
375+
# invalidate any previously cached data.
376+
name = "cohort_diversity_stats_v1"
377+
374378
debug("process cohort parameter")
375379
cohort_query = None
376380
if isinstance(cohort, str):
@@ -391,28 +395,59 @@ def cohort_diversity_stats(
391395
else:
392396
raise TypeError(f"invalid cohort parameter: {cohort!r}")
393397

394-
debug("access allele counts")
395-
ac = self.snp_allele_counts(
398+
params = dict(
399+
cohort_label=cohort_label,
400+
cohort_query=cohort_query,
401+
cohort_size=cohort_size,
396402
region=region,
403+
min_cohort_size=min_cohort_size,
404+
max_cohort_size=max_cohort_size,
397405
site_mask=site_mask,
398406
site_class=site_class,
399-
sample_query=cohort_query,
400407
sample_sets=sample_sets,
401-
cohort_size=cohort_size,
402-
min_cohort_size=min_cohort_size,
403-
max_cohort_size=max_cohort_size,
404408
random_seed=random_seed,
409+
n_jack=n_jack,
410+
confidence_level=confidence_level,
405411
chunks=chunks,
406412
inline_array=inline_array,
407413
)
408414

409-
debug("compute diversity stats")
410-
stats = self._block_jackknife_cohort_diversity_stats(
411-
cohort_label=cohort_label,
412-
ac=ac,
413-
n_jack=n_jack,
414-
confidence_level=confidence_level,
415-
)
415+
# Try to retrieve results from the cache.
416+
try:
417+
results = self.results_cache_get(name=name, params=params)
418+
stats = {
419+
key: value.item()
420+
if isinstance(value, np.ndarray) and value.shape == ()
421+
else value
422+
for key, value in results.items()
423+
}
424+
425+
except CacheMiss:
426+
debug("access allele counts")
427+
ac = self.snp_allele_counts(
428+
region=region,
429+
site_mask=site_mask,
430+
site_class=site_class,
431+
sample_query=cohort_query,
432+
sample_sets=sample_sets,
433+
cohort_size=cohort_size,
434+
min_cohort_size=min_cohort_size,
435+
max_cohort_size=max_cohort_size,
436+
random_seed=random_seed,
437+
chunks=chunks,
438+
inline_array=inline_array,
439+
)
440+
441+
debug("compute diversity stats")
442+
stats = self._block_jackknife_cohort_diversity_stats(
443+
cohort_label=cohort_label,
444+
ac=ac,
445+
n_jack=n_jack,
446+
confidence_level=confidence_level,
447+
)
448+
449+
cache_results = {key: np.asarray(value) for key, value in stats.items()}
450+
self.results_cache_set(name=name, params=params, results=cache_results)
416451

417452
debug("compute some extra cohort variables")
418453
df_samples = self.sample_metadata(
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import random
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from malariagen_data import Ag3
7+
8+
9+
@pytest.fixture
def ag3_sim_api(ag3_sim_fixture, tmp_path):
    """Build an ``Ag3`` client over the simulated bucket with a results cache.

    Both the private and the public URL point at the simulated bucket path
    taken from ``ag3_sim_fixture``, and the results cache is placed in the
    per-test temporary directory supplied by pytest, so each test starts
    with an empty cache.
    """
    bucket_uri = ag3_sim_fixture.bucket_path.as_posix()
    cache_dir = tmp_path.as_posix()
    api_kwargs = dict(
        url=bucket_uri,
        public_url=bucket_uri,
        pre=True,
        check_location=False,
        bokeh_output_notebook=False,
        results_cache=cache_dir,
    )
    return Ag3(**api_kwargs)
20+
21+
22+
def test_cohort_diversity_stats_uses_cache(ag3_sim_api, monkeypatch):
    """Check that a repeated ``cohort_diversity_stats`` call is served from cache.

    The first call runs normally. The jackknife computation method is then
    replaced with a stub that raises, so the second call with identical
    parameters can only succeed if the results are loaded from the cache.
    Finally the two returned series are compared for equality.
    """
    api = ag3_sim_api

    # Pick a random sample set and take up to ten of its samples as the cohort.
    sample_set_ids = api.sample_sets()["sample_set"].to_list()
    chosen_set = random.choice(sample_set_ids)
    metadata = api.sample_metadata(sample_sets=[chosen_set])
    cohort_sample_ids = metadata["sample_id"].head(10).to_list()

    cohort_size = min(5, len(cohort_sample_ids))
    if cohort_size < 2:
        pytest.skip("not enough samples in simulated cohort")

    # The contig is drawn only after the skip check, keeping the order of
    # random draws the same as the sample-set choice followed by the region.
    cohort_spec = ("cache_test", f"sample_id in {cohort_sample_ids!r}")
    region = random.choice(api.contigs)
    call_kwargs = {
        "cohort": cohort_spec,
        "cohort_size": cohort_size,
        "region": region,
        "sample_sets": [chosen_set],
        "random_seed": 42,
        "n_jack": 10,
        "confidence_level": 0.95,
    }

    def _fail_on_recompute(*args, **kwargs):  # noqa: ARG001, ARG002
        message = "cohort_diversity_stats recomputed instead of loading from cache"
        raise AssertionError(message)

    # First call: computed with the real implementation.
    computed = api.cohort_diversity_stats(**call_kwargs)

    # Second call: any attempt to recompute the jackknife stats now raises.
    monkeypatch.setattr(
        api, "_block_jackknife_cohort_diversity_stats", _fail_on_recompute
    )
    cached = api.cohort_diversity_stats(**call_kwargs)

    # Index order and dtypes are not part of the comparison.
    pd.testing.assert_series_equal(
        computed.sort_index(),
        cached.sort_index(),
        check_dtype=False,
    )

0 commit comments

Comments
 (0)