Skip to content

Commit f2c4481

Browse files
committed
Merge remote-tracking branch 'upstream/master' into refactor/xpehh-analysis
2 parents 03b9c86 + e48d75b commit f2c4481

5 files changed

Lines changed: 177 additions & 6 deletions

File tree

malariagen_data/anoph/fst.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Tuple, Optional
23

34
import numpy as np
@@ -43,6 +44,8 @@ def _fst_gwss(
4344
inline_array,
4445
chunks,
4546
clip_min,
47+
min_snps_threshold,
48+
window_adjustment_factor,
4649
):
4750
# Compute allele counts.
4851
ac1 = self.snp_allele_counts(
@@ -81,6 +84,24 @@ def _fst_gwss(
8184
chunks=chunks,
8285
).compute()
8386

87+
n_snps = len(pos)
88+
if n_snps < min_snps_threshold:
89+
raise ValueError(
90+
f"Too few SNP sites ({n_snps}) available for Fst GWSS. "
91+
f"At least {min_snps_threshold} sites are required. "
92+
"Try a larger genomic region or different site selection criteria."
93+
)
94+
if window_size >= n_snps:
95+
adjusted_window_size = max(1, n_snps // window_adjustment_factor)
96+
warnings.warn(
97+
f"window_size ({window_size}) is >= the number of SNP sites "
98+
f"available ({n_snps}); automatically adjusting window_size to "
99+
f"{adjusted_window_size} (= {n_snps} // {window_adjustment_factor}).",
100+
UserWarning,
101+
stacklevel=2,
102+
)
103+
window_size = adjusted_window_size
104+
84105
with self._spinner(desc="Compute Fst"):
85106
with np.errstate(divide="ignore", invalid="ignore"):
86107
fst = allel.moving_hudson_fst(ac1, ac2, size=window_size)
@@ -96,8 +117,23 @@ def _fst_gwss(
96117
@doc(
97118
summary="""
98119
Run a Fst genome-wide scan to investigate genetic differentiation
99-
between two cohorts.
120+
between two cohorts. If window_size is >= the number of available
121+
SNP sites, a UserWarning is issued and window_size is automatically
122+
adjusted to number_of_snps // window_adjustment_factor. A ValueError
123+
is raised if the number of available SNP sites is below
124+
min_snps_threshold.
100125
""",
126+
parameters=dict(
127+
min_snps_threshold="""
128+
Minimum number of SNP sites required. If fewer sites are
129+
available a ValueError is raised.
130+
""",
131+
window_adjustment_factor="""
132+
If window_size is >= the number of available SNP sites,
133+
window_size is automatically set to
134+
number_of_snps // window_adjustment_factor.
135+
""",
136+
),
101137
returns=dict(
102138
x="An array containing the window centre point genomic positions",
103139
fst="An array with Fst statistic values for each window.",
@@ -123,6 +159,8 @@ def fst_gwss(
123159
inline_array: base_params.inline_array = base_params.inline_array_default,
124160
chunks: base_params.chunks = base_params.native_chunks,
125161
clip_min: fst_params.clip_min = 0.0,
162+
min_snps_threshold: fst_params.min_snps_threshold = 1000,
163+
window_adjustment_factor: fst_params.window_adjustment_factor = 10,
126164
) -> Tuple[np.ndarray, np.ndarray]:
127165
# Change this name if you ever change the behaviour of this function, to
128166
# invalidate any previously cached data.
@@ -147,7 +185,13 @@ def fst_gwss(
147185
results = self.results_cache_get(name=name, params=params)
148186

149187
except CacheMiss:
150-
results = self._fst_gwss(**params, inline_array=inline_array, chunks=chunks)
188+
results = self._fst_gwss(
189+
**params,
190+
inline_array=inline_array,
191+
chunks=chunks,
192+
min_snps_threshold=min_snps_threshold,
193+
window_adjustment_factor=window_adjustment_factor,
194+
)
151195
self.results_cache_set(name=name, params=params, results=results)
152196

153197
x = results["x"]

malariagen_data/anoph/fst_params.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@
3434
""",
3535
]
3636

37+
min_snps_threshold: TypeAlias = Annotated[
38+
int,
39+
"""
40+
Minimum number of SNP sites required for the Fst GWSS computation. If
41+
fewer sites are available, a ValueError is raised.
42+
""",
43+
]
44+
45+
window_adjustment_factor: TypeAlias = Annotated[
46+
int,
47+
"""
48+
If window_size is >= the number of available SNP sites, the window_size
49+
is automatically adjusted to number_of_snps // window_adjustment_factor.
50+
""",
51+
]
52+
3753
annotation: TypeAlias = Annotated[
3854
Optional[Literal["standard error", "Z score", "lower triangle"]],
3955
"""

malariagen_data/anoph/snp_data.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from collections import OrderedDict
23
from functools import lru_cache
34
from typing import Any, Dict, List, Optional, Tuple, Union
45

@@ -75,7 +76,9 @@ def __init__(
7576
base_params.site_mask, zarr.hierarchy.Group
7677
] = dict()
7778
self._cache_site_annotations: Optional[zarr.hierarchy.Group] = None
78-
self._cache_locate_site_class: Dict[Tuple[Any, ...], np.ndarray] = dict()
79+
self._cache_locate_site_class: OrderedDict[
80+
Tuple[Any, ...], np.ndarray
81+
] = OrderedDict()
7982

8083
# Create the SNP-calls cache as a per-instance lru_cache wrapping the
8184
# bound method. Storing it on the instance (rather than using a
@@ -878,6 +881,9 @@ def _locate_site_class(
878881

879882
try:
880883
loc_ann = self._cache_locate_site_class[cache_key]
884+
# Promote to most-recently-used so the LRU eviction below
885+
# always removes the *least*-recently-used entry.
886+
self._cache_locate_site_class.move_to_end(cache_key)
881887

882888
except KeyError as exc:
883889
# Access site annotations data.
@@ -1023,9 +1029,10 @@ def _locate_site_class(
10231029

10241030
self._cache_locate_site_class[cache_key] = loc_ann
10251031

1026-
# Evict the oldest entry when the cache exceeds its size limit.
1027-
# Plain dicts preserve insertion order (Python 3.7+), so the first
1028-
# key is always the oldest.
1032+
# Evict the least-recently-used entry when the cache exceeds its
1033+
# size limit. Because the cache is an OrderedDict and both hits
1034+
# (move_to_end above) and inserts append to the right, the first
1035+
# key is always the *least*-recently-used entry.
10291036
while len(self._cache_locate_site_class) > _LOCATE_SITE_CLASS_CACHE_MAXSIZE:
10301037
oldest = next(iter(self._cache_locate_site_class))
10311038
del self._cache_locate_site_class[oldest]

tests/anoph/test_fst.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,52 @@ def test_fst_gwss(fixture, api: AnophelesFstAnalysis):
139139
assert isinstance(fig, bokeh.models.GridPlot)
140140

141141

142+
@parametrize_with_cases("fixture,api", cases=".")
143+
def test_fst_gwss_window_size_too_large(fixture, api: AnophelesFstAnalysis):
144+
# When window_size exceeds available SNPs, a UserWarning must be issued and
145+
# the function must still return a valid result using the adjusted window_size.
146+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
147+
all_countries = api.sample_metadata()["country"].dropna().unique().tolist()
148+
countries = np.random.choice(all_countries, size=2, replace=False).tolist()
149+
cohort1_query = f"country == {countries[0]!r}"
150+
cohort2_query = f"country == {countries[1]!r}"
151+
with pytest.warns(UserWarning, match="window_size"):
152+
x, fst = api.fst_gwss(
153+
contig=str(np.random.choice(api.contigs)),
154+
sample_sets=all_sample_sets,
155+
cohort1_query=cohort1_query,
156+
cohort2_query=cohort2_query,
157+
site_mask=str(np.random.choice(api.site_mask_ids)),
158+
window_size=10_000_000, # far larger than any fixture SNP count
159+
min_cohort_size=1,
160+
)
161+
assert isinstance(x, np.ndarray)
162+
assert isinstance(fst, np.ndarray)
163+
assert len(x) > 0
164+
assert x.shape == fst.shape
165+
166+
167+
@parametrize_with_cases("fixture,api", cases=".")
168+
def test_fst_gwss_too_few_snps(fixture, api: AnophelesFstAnalysis):
169+
# When min_snps_threshold exceeds available SNPs, a ValueError must be raised.
170+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
171+
all_countries = api.sample_metadata()["country"].dropna().unique().tolist()
172+
countries = np.random.choice(all_countries, size=2, replace=False).tolist()
173+
cohort1_query = f"country == {countries[0]!r}"
174+
cohort2_query = f"country == {countries[1]!r}"
175+
with pytest.raises(ValueError, match="Too few SNP sites"):
176+
api.fst_gwss(
177+
contig=str(np.random.choice(api.contigs)),
178+
sample_sets=all_sample_sets,
179+
cohort1_query=cohort1_query,
180+
cohort2_query=cohort2_query,
181+
site_mask=str(np.random.choice(api.site_mask_ids)),
182+
window_size=100,
183+
min_cohort_size=1,
184+
min_snps_threshold=10_000_000, # far larger than any fixture SNP count (~28k-70k)
185+
)
186+
187+
142188
@parametrize_with_cases("fixture,api", cases=".")
143189
def test_average_fst(fixture, api: AnophelesFstAnalysis):
144190
# Set up test parameters.

tests/anoph/test_snp_data.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,64 @@ def test_locate_site_class_cache_is_bounded(ag3_sim_api: AnophelesSnpData):
973973
assert len(ag3_sim_api._cache_locate_site_class) <= _LOCATE_SITE_CLASS_CACHE_MAXSIZE
974974

975975

976+
def test_locate_site_class_cache_lru_eviction(ag3_sim_api: AnophelesSnpData):
977+
"""Verify true LRU semantics: recently *accessed* entries survive eviction,
978+
while least-recently-used entries are evicted first."""
979+
from collections import OrderedDict
980+
981+
from malariagen_data.anoph.snp_data import _LOCATE_SITE_CLASS_CACHE_MAXSIZE
982+
983+
cache = ag3_sim_api._cache_locate_site_class
984+
985+
# Start from a clean cache.
986+
cache.clear()
987+
assert isinstance(cache, OrderedDict)
988+
989+
maxsize = _LOCATE_SITE_CLASS_CACHE_MAXSIZE # 64
990+
991+
# --- Phase 1: fill the cache to exactly maxsize ---
992+
dummy = np.array([True, False])
993+
for i in range(maxsize):
994+
key = (f"contig_{i}", f"mask_{i}", f"class_{i}")
995+
cache[key] = dummy
996+
assert len(cache) == maxsize
997+
998+
# Remember the first key inserted (the oldest / least-recently-used).
999+
first_key = ("contig_0", "mask_0", "class_0")
1000+
assert first_key in cache
1001+
1002+
# --- Phase 2: simulate an access (LRU promotion) on the first key ---
1003+
# move_to_end makes it the most-recently-used entry.
1004+
cache.move_to_end(first_key)
1005+
1006+
# Insert one more entry, exceeding maxsize.
1007+
overflow_key = ("overflow", "mask", "class")
1008+
cache[overflow_key] = dummy
1009+
1010+
# Evict to maintain the bound (same logic as _locate_site_class).
1011+
while len(cache) > maxsize:
1012+
oldest = next(iter(cache))
1013+
del cache[oldest]
1014+
1015+
# The first key should STILL be present because it was promoted.
1016+
assert (
1017+
first_key in cache
1018+
), "LRU promotion via move_to_end must keep recently accessed entries alive"
1019+
1020+
# The second key ("contig_1", ...) — which was never re-accessed —
1021+
# should have been evicted as the new least-recently-used entry.
1022+
second_key = ("contig_1", "mask_1", "class_1")
1023+
assert (
1024+
second_key not in cache
1025+
), "The least-recently-used entry should be evicted when cache exceeds maxsize"
1026+
1027+
# The overflow key should be present (it was just inserted).
1028+
assert overflow_key in cache
1029+
1030+
# Cache size must remain bounded.
1031+
assert len(cache) == maxsize
1032+
1033+
9761034
def test_snp_calls_cache_is_per_instance(ag3_sim_api: AnophelesSnpData):
9771035
"""_cached_snp_calls must be a per-instance lru_cache, not a class-level one.
9781036

0 commit comments

Comments
 (0)