Skip to content

Commit 43c0dfa

Browse files
authored
Merge pull request #1027 from Sharon-codes/issue-919-cnv-variant-query
test(cnv): stabilise simulated high-variance sampling
2 parents c269768 + 9c1192e commit 43c0dfa

2 files changed

Lines changed: 86 additions & 5 deletions

File tree

tests/anoph/conftest.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -334,7 +334,7 @@ def simulate_exons(
334334
# keep things simple for now.
335335
if strand == "-":
336336
# Take exons in reverse order.
337-
exons == exons[::-1]
337+
exons = exons[::-1]
338338
for exon_ix, exon in enumerate(exons):
339339
first_exon = exon_ix == 0
340340
last_exon = exon_ix == len(exons) - 1
@@ -646,8 +646,8 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes, rng):
646646
# - sample_is_high_variance [1D array] [bool] [True or False for n_samples]
647647
# - samples [1D array] [str]
648648

649-
# Get a random probability for a sample being high variance, between 0 and 1.
650-
p_variance = rng.random()
649+
# Keep high variance sample prevalence stable for deterministic tests.
650+
p_variance = 0.1
651651

652652
# Open a zarr at the specified path.
653653
root = zarr.open(zarr_path, mode="w")
@@ -862,8 +862,8 @@ def simulate_cnv_discordant_read_calls(
862862
# - sample_is_high_variance [1D array] [bool] [True or False for n_samples]
863863
# - samples [1D array] [str for n_samples]
864864

865-
# Get a random probability for a sample being high variance, between 0 and 1.
866-
p_variance = rng.random()
865+
# Keep high variance sample prevalence stable for deterministic tests.
866+
p_variance = 0.1
867867

868868
# Get a random probability for choosing allele 1, between 0 and 1.
869869
p_allele = rng.random()
Lines changed: 81 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,81 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pandas as pd
5+
import zarr
6+
7+
from .conftest import (
8+
Gff3Simulator,
9+
simulate_cnv_discordant_read_calls,
10+
simulate_cnv_hmm,
11+
)
12+
13+
14+
def _write_sample_metadata(path: Path, n_samples: int = 100) -> None:
    """Write a minimal sample metadata CSV to *path*.

    The file holds a single ``sample_id`` column with ``n_samples``
    zero-padded identifiers (``S0000``, ``S0001``, ...) and no index column.
    """
    sample_ids = [f"S{sample_ix:04d}" for sample_ix in range(n_samples)]
    metadata = pd.DataFrame({"sample_id": sample_ids})
    metadata.to_csv(path, index=False)
17+
18+
19+
def test_simulate_cnv_hmm_limits_high_variance_fraction(tmp_path):
    """Simulated CNV HMM data flags fewer than 30% of samples as high variance."""
    # Both artefacts live under pytest's per-test temporary directory.
    store_path = tmp_path / "cnv_hmm.zarr"
    samples_csv = tmp_path / "samples.csv"

    # Write the sample metadata file the simulator is pointed at.
    _write_sample_metadata(samples_csv)

    # A fixed seed keeps the simulated output reproducible between runs.
    simulate_cnv_hmm(
        zarr_path=store_path,
        metadata_path=samples_csv,
        contigs=("2L",),
        contig_sizes={"2L": 10_000},
        rng=np.random.default_rng(0),
    )

    # Re-open the store read-only and inspect the per-sample flags.
    store = zarr.open(store_path, mode="r")
    flags = store["sample_is_high_variance"][:]
    assert np.mean(flags) < 0.3
35+
36+
37+
def test_simulate_cnv_discordant_read_calls_limits_high_variance_fraction(tmp_path):
    """Simulated discordant-read CNV calls flag fewer than 30% of samples as high variance."""
    # Both artefacts live under pytest's per-test temporary directory.
    store_path = tmp_path / "cnv_discordant.zarr"
    samples_csv = tmp_path / "samples.csv"

    # Write the sample metadata file the simulator is pointed at.
    _write_sample_metadata(samples_csv)

    # A fixed seed keeps the simulated output reproducible between runs.
    simulate_cnv_discordant_read_calls(
        zarr_path=store_path,
        metadata_path=samples_csv,
        contigs=("2L",),
        contig_sizes={"2L": 10_000},
        rng=np.random.default_rng(0),
    )

    # Re-open the store read-only and inspect the per-sample flags.
    store = zarr.open(store_path, mode="r")
    flags = store["sample_is_high_variance"][:]
    assert np.mean(flags) < 0.3
53+
54+
55+
def test_simulate_exons_on_minus_strand_reverses_feature_order():
    """Check that minus-strand UTR/CDS features are emitted in descending start order.

    The simulator is configured with identical low/high bounds for the exon
    count, intron size and exon size, and a fixed seed, so the run is
    reproducible.
    """
    sim = Gff3Simulator(
        contig_sizes={"2L": 10_000},
        rng=np.random.default_rng(0),
        n_exons_low=3,
        n_exons_high=3,
        intron_size_low=10,
        intron_size_high=10,
        exon_size_low=100,
        exon_size_high=100,
    )
    rows = list(
        sim.simulate_exons(
            contig="2L",
            strand="-",
            gene_ix=0,
            transcript_ix=0,
            transcript_id="transcript-2L-0-0",
            transcript_start=1,
            transcript_end=1_000,
        )
    )

    # Keep only UTR and CDS rows. Presumably rows follow GFF3 column order
    # (feature type at index 2, start coordinate at index 3) -- confirm
    # against the simulator if the row layout ever changes.
    wanted_types = {sim.utr5_type, sim.utr3_type, sim.cds_type}
    starts = [row[3] for row in rows if row[2] in wanted_types]

    # Guard against a vacuous pass: an empty list trivially equals its own
    # sorted copy, which would hide a simulator that yields no UTR/CDS rows.
    assert starts, "expected at least one UTR or CDS feature to be simulated"
    assert starts == sorted(starts, reverse=True)

0 commit comments

Comments
 (0)