Skip to content

Commit 7b9fcaf

Browse files
authored
Merge pull request #911 from adilraza99/GH907-n-jack-validation
Guard against invalid jackknife block sizing when n_jack exceeds site count
2 parents 7c49d0a + b0519b4 commit 7b9fcaf

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

malariagen_data/anoph/fst.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,14 @@ def average_fst(
404404
)
405405

406406
# Calculate block length for jackknife.
407-
n_sites = ac1.shape[0] # number of sites
408-
block_length = n_sites // n_jack # number of sites in each block
407+
n_sites = ac1.shape[0]
408+
block_length = n_sites // n_jack
409+
410+
if block_length < 1:
411+
raise ValueError(
412+
f"Not enough sites ({n_sites}) for {n_jack} jackknife blocks. "
413+
"Choose a larger region or reduce n_jack."
414+
)
409415

410416
# Calculate average Fst.
411417
fst, se, _, _ = allel.blockwise_hudson_fst(ac1, ac2, blen=block_length)

tests/anoph/test_fst.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,27 @@ def test_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAnalysis):
190190
api.average_fst(**fst_params)
191191

192192

193+
@parametrize_with_cases("fixture,api", cases=".")
194+
def test_average_fst_region_too_small(fixture, api: AnophelesFstAnalysis):
195+
"""ValueError should be raised when block_length == 0 (n_jack > n_sites)."""
196+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
197+
all_countries = api.sample_metadata()["country"].dropna().unique().tolist()
198+
countries = random.sample(all_countries, 2)
199+
cohort1_query = f"country == {countries[0]!r}"
200+
cohort2_query = f"country == {countries[1]!r}"
201+
fst_params = dict(
202+
region=random.choice(api.contigs),
203+
sample_sets=all_sample_sets,
204+
cohort1_query=cohort1_query,
205+
cohort2_query=cohort2_query,
206+
site_mask=random.choice(api.site_mask_ids),
207+
min_cohort_size=1,
208+
n_jack=1_000_000, # deliberately exceeds available sites
209+
)
210+
with pytest.raises(ValueError, match="Not enough sites"):
211+
api.average_fst(**fst_params)
212+
213+
193214
def check_pairwise_average_fst(api: AnophelesFstAnalysis, fst_params):
194215
# Run main function under test.
195216
fst_df = api.pairwise_average_fst(**fst_params)

0 commit comments

Comments
 (0)