Skip to content

Commit e6107e9

Browse files
authored
Merge branch 'master' into GH-1054-add-vcf-export
2 parents 183c2e3 + 43c0dfa commit e6107e9

4 files changed

Lines changed: 135 additions & 7 deletions

File tree

malariagen_data/anoph/fst.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -360,7 +360,7 @@ def plot_fst_gwss(
360360
)
361361
def average_fst(
362362
self,
363-
region: base_params.region,
363+
region: base_params.regions,
364364
cohort1_query: base_params.sample_query,
365365
cohort2_query: base_params.sample_query,
366366
sample_query_options: Optional[base_params.sample_query_options] = None,
@@ -435,7 +435,7 @@ def average_fst(
435435
)
436436
def pairwise_average_fst(
437437
self,
438-
region: base_params.region,
438+
region: base_params.regions,
439439
cohorts: base_params.cohorts,
440440
sample_sets: Optional[base_params.sample_sets] = None,
441441
sample_query: Optional[base_params.sample_query] = None,

tests/anoph/conftest.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -334,7 +334,7 @@ def simulate_exons(
334334
# keep things simple for now.
335335
if strand == "-":
336336
# Take exons in reverse order.
337-
exons == exons[::-1]
337+
exons = exons[::-1]
338338
for exon_ix, exon in enumerate(exons):
339339
first_exon = exon_ix == 0
340340
last_exon = exon_ix == len(exons) - 1
@@ -646,8 +646,8 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes, rng):
646646
# - sample_is_high_variance [1D array] [bool] [True or False for n_samples]
647647
# - samples [1D array] [str]
648648

649-
# Get a random probability for a sample being high variance, between 0 and 1.
650-
p_variance = rng.random()
649+
# Keep high variance sample prevalence stable for deterministic tests.
650+
p_variance = 0.1
651651

652652
# Open a zarr at the specified path.
653653
root = zarr.open(zarr_path, mode="w")
@@ -862,8 +862,8 @@ def simulate_cnv_discordant_read_calls(
862862
# - sample_is_high_variance [1D array] [bool] [True or False for n_samples]
863863
# - samples [1D array] [str for n_samples]
864864

865-
# Get a random probability for a sample being high variance, between 0 and 1.
866-
p_variance = rng.random()
865+
# Keep high variance sample prevalence stable for deterministic tests.
866+
p_variance = 0.1
867867

868868
# Get a random probability for choosing allele 1, between 0 and 1.
869869
p_allele = rng.random()

tests/anoph/test_fst.py

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -368,3 +368,50 @@ def test_pairwise_average_fst_with_bad_cohorts(fixture, api: AnophelesFstAnalysi
368368
# Run function under test.
369369
with pytest.raises(ValueError):
370370
api.pairwise_average_fst(**fst_params)
371+
372+
373+
@parametrize_with_cases("fixture,api", cases=".")
374+
def test_average_fst_with_list_of_regions(fixture, api: AnophelesFstAnalysis):
375+
# Set up test parameters.
376+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
377+
all_countries = api.sample_metadata()["country"].dropna().unique().tolist()
378+
countries = random.sample(all_countries, 2)
379+
cohort1_query = f"country == {countries[0]!r}"
380+
cohort2_query = f"country == {countries[1]!r}"
381+
fst_params = dict(
382+
region=random.sample(api.contigs, 2),
383+
sample_sets=all_sample_sets,
384+
cohort1_query=cohort1_query,
385+
cohort2_query=cohort2_query,
386+
site_mask=random.choice(api.site_mask_ids),
387+
min_cohort_size=1,
388+
n_jack=random.randint(10, 200),
389+
)
390+
391+
# Run function under test.
392+
fst, se = api.average_fst(**fst_params)
393+
394+
# Checks.
395+
assert isinstance(fst, float)
396+
assert isinstance(se, float)
397+
assert 0 <= fst <= 1
398+
assert 0 <= se <= 1
399+
400+
401+
@parametrize_with_cases("fixture,api", cases=".")
402+
def test_pairwise_average_fst_with_list_of_regions(fixture, api: AnophelesFstAnalysis):
403+
# Set up test parameters.
404+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
405+
region = random.sample(api.contigs, 2)
406+
site_mask = random.choice(api.site_mask_ids)
407+
fst_params = dict(
408+
region=region,
409+
cohorts="country",
410+
sample_sets=all_sample_sets,
411+
site_mask=site_mask,
412+
min_cohort_size=1,
413+
n_jack=random.randint(10, 200),
414+
)
415+
416+
# Run checks.
417+
check_pairwise_average_fst(api=api, fst_params=fst_params)
Lines changed: 81 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,81 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pandas as pd
5+
import zarr
6+
7+
from .conftest import (
8+
Gff3Simulator,
9+
simulate_cnv_discordant_read_calls,
10+
simulate_cnv_hmm,
11+
)
12+
13+
14+
def _write_sample_metadata(path: Path, n_samples: int = 100) -> None:
15+
df_samples = pd.DataFrame({"sample_id": [f"S{i:04d}" for i in range(n_samples)]})
16+
df_samples.to_csv(path, index=False)
17+
18+
19+
def test_simulate_cnv_hmm_limits_high_variance_fraction(tmp_path):
20+
zarr_path = tmp_path / "cnv_hmm.zarr"
21+
metadata_path = tmp_path / "samples.csv"
22+
_write_sample_metadata(metadata_path)
23+
24+
simulate_cnv_hmm(
25+
zarr_path=zarr_path,
26+
metadata_path=metadata_path,
27+
contigs=("2L",),
28+
contig_sizes={"2L": 10_000},
29+
rng=np.random.default_rng(0),
30+
)
31+
32+
root = zarr.open(zarr_path, mode="r")
33+
high_variance_fraction = np.mean(root["sample_is_high_variance"][:])
34+
assert high_variance_fraction < 0.3
35+
36+
37+
def test_simulate_cnv_discordant_read_calls_limits_high_variance_fraction(tmp_path):
38+
zarr_path = tmp_path / "cnv_discordant.zarr"
39+
metadata_path = tmp_path / "samples.csv"
40+
_write_sample_metadata(metadata_path)
41+
42+
simulate_cnv_discordant_read_calls(
43+
zarr_path=zarr_path,
44+
metadata_path=metadata_path,
45+
contigs=("2L",),
46+
contig_sizes={"2L": 10_000},
47+
rng=np.random.default_rng(0),
48+
)
49+
50+
root = zarr.open(zarr_path, mode="r")
51+
high_variance_fraction = np.mean(root["sample_is_high_variance"][:])
52+
assert high_variance_fraction < 0.3
53+
54+
55+
def test_simulate_exons_on_minus_strand_reverses_feature_order():
56+
sim = Gff3Simulator(
57+
contig_sizes={"2L": 10_000},
58+
rng=np.random.default_rng(0),
59+
n_exons_low=3,
60+
n_exons_high=3,
61+
intron_size_low=10,
62+
intron_size_high=10,
63+
exon_size_low=100,
64+
exon_size_high=100,
65+
)
66+
rows = list(
67+
sim.simulate_exons(
68+
contig="2L",
69+
strand="-",
70+
gene_ix=0,
71+
transcript_ix=0,
72+
transcript_id="transcript-2L-0-0",
73+
transcript_start=1,
74+
transcript_end=1_000,
75+
)
76+
)
77+
cds_and_utrs = [
78+
row for row in rows if row[2] in {sim.utr5_type, sim.utr3_type, sim.cds_type}
79+
]
80+
starts = [row[3] for row in cds_and_utrs]
81+
assert starts == sorted(starts, reverse=True)

0 commit comments

Comments
 (0)