Skip to content

Commit 8fd8a08

Browse files
committed
Use simulated cache test for cohort_diversity_stats
1 parent f3e47f6 commit 8fd8a08

3 files changed

Lines changed: 58 additions & 112 deletions

File tree

malariagen_data/anopheles.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1123,7 +1123,8 @@ def cohort_diversity_stats(
11231123
confidence_level=confidence_level,
11241124
)
11251125

1126-
self.results_cache_set(name=name, params=params, results=stats)
1126+
cache_results = {key: np.asarray(value) for key, value in stats.items()}
1127+
self.results_cache_set(name=name, params=params, results=cache_results)
11271128

11281129
debug("compute some extra cohort variables")
11291130
df_samples = self.sample_metadata(
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import random
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from malariagen_data import Ag3
7+
8+
9+
@pytest.fixture
10+
def ag3_sim_api(ag3_sim_fixture, tmp_path):
11+
data_path = ag3_sim_fixture.bucket_path.as_posix()
12+
return Ag3(
13+
url=data_path,
14+
public_url=data_path,
15+
pre=True,
16+
check_location=False,
17+
bokeh_output_notebook=False,
18+
results_cache=tmp_path.as_posix(),
19+
)
20+
21+
22+
def test_cohort_diversity_stats_uses_cache(ag3_sim_api, monkeypatch):
23+
api = ag3_sim_api
24+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
25+
sample_set = random.choice(all_sample_sets)
26+
df_samples = api.sample_metadata(sample_sets=[sample_set])
27+
cohort_sample_ids = df_samples["sample_id"].head(10).to_list()
28+
cohort_size = min(5, len(cohort_sample_ids))
29+
if cohort_size < 2:
30+
pytest.skip("not enough samples in simulated cohort")
31+
32+
params = dict(
33+
cohort=("cache_test", f"sample_id in {cohort_sample_ids!r}"),
34+
cohort_size=cohort_size,
35+
region=random.choice(api.contigs),
36+
sample_sets=[sample_set],
37+
random_seed=42,
38+
n_jack=10,
39+
confidence_level=0.95,
40+
)
41+
42+
stats_first = api.cohort_diversity_stats(**params)
43+
44+
def _unexpected_recompute(*args, **kwargs): # noqa: ARG001, ARG002
45+
raise AssertionError(
46+
"cohort_diversity_stats recomputed instead of loading from cache"
47+
)
48+
49+
monkeypatch.setattr(api, "_block_jackknife_cohort_diversity_stats", _unexpected_recompute)
50+
stats_second = api.cohort_diversity_stats(**params)
51+
52+
pd.testing.assert_series_equal(
53+
stats_first.sort_index(),
54+
stats_second.sort_index(),
55+
check_dtype=False,
56+
)

tests/test_anopheles.py

Lines changed: 0 additions & 111 deletions
This file was deleted.

0 commit comments

Comments
 (0)