Skip to content

Commit d021f51

Browse files
committed
Moved the tests for the plotting functions
1 parent 9837220 commit d021f51

1 file changed

Lines changed: 318 additions & 0 deletions

File tree

tests/anoph/test_frq.py

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
import random
2+
3+
import pytest
4+
from pytest_cases import parametrize_with_cases
5+
import plotly.graph_objects as go # type: ignore
6+
7+
from malariagen_data import af1 as _af1
8+
from malariagen_data import ag3 as _ag3
9+
from malariagen_data.anoph.snp_frq import AnophelesSnpFrequencyAnalysis
10+
11+
from .test_snp_frq import random_transcript
12+
13+
14+
@pytest.fixture
15+
def ag3_sim_api(ag3_sim_fixture):
16+
return AnophelesSnpFrequencyAnalysis(
17+
url=ag3_sim_fixture.url,
18+
config_path=_ag3.CONFIG_PATH,
19+
major_version_number=_ag3.MAJOR_VERSION_NUMBER,
20+
major_version_path=_ag3.MAJOR_VERSION_PATH,
21+
pre=True,
22+
aim_metadata_dtype={
23+
"aim_species_fraction_arab": "float64",
24+
"aim_species_fraction_colu": "float64",
25+
"aim_species_fraction_colu_no2l": "float64",
26+
"aim_species_gambcolu_arabiensis": object,
27+
"aim_species_gambiae_coluzzii": object,
28+
"aim_species": object,
29+
},
30+
gff_gene_type="gene",
31+
gff_gene_name_attribute="Name",
32+
gff_default_attributes=("ID", "Parent", "Name", "description"),
33+
default_site_mask="gamb_colu_arab",
34+
results_cache=ag3_sim_fixture.results_cache_path.as_posix(),
35+
taxon_colors=_ag3.TAXON_COLORS,
36+
)
37+
38+
39+
@pytest.fixture
40+
def af1_sim_api(af1_sim_fixture):
41+
return AnophelesSnpFrequencyAnalysis(
42+
url=af1_sim_fixture.url,
43+
config_path=_af1.CONFIG_PATH,
44+
major_version_number=_af1.MAJOR_VERSION_NUMBER,
45+
major_version_path=_af1.MAJOR_VERSION_PATH,
46+
pre=False,
47+
gff_gene_type="protein_coding_gene",
48+
gff_gene_name_attribute="Note",
49+
gff_default_attributes=("ID", "Parent", "Note", "description"),
50+
default_site_mask="funestus",
51+
results_cache=af1_sim_fixture.results_cache_path.as_posix(),
52+
taxon_colors=_af1.TAXON_COLORS,
53+
)
54+
55+
56+
# N.B., here we use pytest_cases to parametrize tests. Each
57+
# function whose name begins with "case_" defines a set of
58+
# inputs to the test functions. See the documentation for
59+
# pytest_cases for more information, e.g.:
60+
#
61+
# https://smarie.github.io/python-pytest-cases/#basic-usage
62+
#
63+
# We use this approach here because we want to use fixtures
64+
# as test parameters, which is otherwise hard to do with
65+
# pytest alone.
66+
67+
68+
def case_ag3_sim(ag3_sim_fixture, ag3_sim_api):
69+
return ag3_sim_fixture, ag3_sim_api
70+
71+
72+
def case_af1_sim(af1_sim_fixture, af1_sim_api):
73+
return af1_sim_fixture, af1_sim_api
74+
75+
76+
@parametrize_with_cases("fixture,api", cases=".")
77+
def test_plot_frequencies_heatmap(
78+
fixture,
79+
api: AnophelesSnpFrequencyAnalysis,
80+
):
81+
# Pick test parameters at random.
82+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
83+
sample_sets = random.choice(all_sample_sets)
84+
site_mask = random.choice(api.site_mask_ids + (None,))
85+
min_cohort_size = random.randint(0, 2)
86+
transcript = random_transcript(api=api).name
87+
cohorts = random.choice(
88+
["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
89+
)
90+
91+
# Set up call params.
92+
params = dict(
93+
transcript=transcript,
94+
cohorts=cohorts,
95+
min_cohort_size=min_cohort_size,
96+
site_mask=site_mask,
97+
sample_sets=sample_sets,
98+
)
99+
100+
# Test SNP allele frequencies.
101+
df_snp = api.snp_allele_frequencies(**params)
102+
fig = api.plot_frequencies_heatmap(df_snp, show=False, max_len=None)
103+
assert isinstance(fig, go.Figure)
104+
105+
# Test amino acid change allele frequencies.
106+
df_aa = api.aa_allele_frequencies(**params)
107+
fig = api.plot_frequencies_heatmap(df_aa, show=False, max_len=None)
108+
assert isinstance(fig, go.Figure)
109+
110+
# Test max_len behaviour.
111+
with pytest.raises(ValueError):
112+
api.plot_frequencies_heatmap(df_snp, show=False, max_len=len(df_snp) - 1)
113+
114+
# Test index parameter - if None, should use dataframe index.
115+
fig = api.plot_frequencies_heatmap(df_snp, show=False, index=None, max_len=None)
116+
# Not unique.
117+
with pytest.raises(ValueError):
118+
api.plot_frequencies_heatmap(df_snp, show=False, index="contig", max_len=None)
119+
120+
121+
@parametrize_with_cases("fixture,api", cases=".")
122+
def test_plot_frequencies_time_series(
123+
fixture,
124+
api: AnophelesSnpFrequencyAnalysis,
125+
):
126+
# Pick test parameters at random.
127+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
128+
sample_sets = random.choice(all_sample_sets)
129+
site_mask = random.choice(api.site_mask_ids + (None,))
130+
min_cohort_size = random.randint(0, 2)
131+
transcript = random_transcript(api=api).name
132+
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
133+
period_by = random.choice(["year", "quarter", "month"])
134+
135+
# Compute SNP frequencies.
136+
ds = api.snp_allele_frequencies_advanced(
137+
transcript=transcript,
138+
area_by=area_by,
139+
period_by=period_by,
140+
sample_sets=sample_sets,
141+
min_cohort_size=min_cohort_size,
142+
site_mask=site_mask,
143+
)
144+
145+
# Trim things down a bit for speed.
146+
ds = ds.isel(variants=slice(0, 100))
147+
148+
# Plot.
149+
fig = api.plot_frequencies_time_series(ds, show=False)
150+
151+
# Test.
152+
assert isinstance(fig, go.Figure)
153+
154+
# Compute amino acid change frequencies.
155+
ds = api.aa_allele_frequencies_advanced(
156+
transcript=transcript,
157+
area_by=area_by,
158+
period_by=period_by,
159+
sample_sets=sample_sets,
160+
min_cohort_size=min_cohort_size,
161+
)
162+
163+
# Trim things down a bit for speed.
164+
ds = ds.isel(variants=slice(0, 100))
165+
166+
# Plot.
167+
fig = api.plot_frequencies_time_series(ds, show=False)
168+
169+
# Test.
170+
assert isinstance(fig, go.Figure)
171+
172+
173+
@parametrize_with_cases("fixture,api", cases=".")
174+
def test_plot_frequencies_time_series_with_taxa(
175+
fixture,
176+
api: AnophelesSnpFrequencyAnalysis,
177+
):
178+
# Pick test parameters at random.
179+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
180+
sample_sets = random.choice(all_sample_sets)
181+
site_mask = random.choice(api.site_mask_ids + (None,))
182+
transcript = random_transcript(api=api).name
183+
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
184+
period_by = random.choice(["year", "quarter", "month"])
185+
186+
# Pick a random taxon and taxa from valid taxa.
187+
sample_sets_taxa = (
188+
api.sample_metadata(sample_sets=sample_sets)["taxon"].dropna().unique().tolist()
189+
)
190+
taxon = random.choice(sample_sets_taxa)
191+
taxa = random.sample(sample_sets_taxa, random.randint(1, len(sample_sets_taxa)))
192+
193+
# Compute SNP frequencies.
194+
ds = api.snp_allele_frequencies_advanced(
195+
transcript=transcript,
196+
area_by=area_by,
197+
period_by=period_by,
198+
sample_sets=sample_sets,
199+
min_cohort_size=1, # Don't exclude any samples.
200+
site_mask=site_mask,
201+
)
202+
203+
# Trim things down a bit for speed.
204+
ds = ds.isel(variants=slice(0, 100))
205+
206+
# Plot with taxon.
207+
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon)
208+
209+
# Test taxon plot.
210+
assert isinstance(fig, go.Figure)
211+
212+
# Plot with taxa.
213+
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxa)
214+
215+
# Test taxa plot.
216+
assert isinstance(fig, go.Figure)
217+
218+
219+
@parametrize_with_cases("fixture,api", cases=".")
220+
def test_plot_frequencies_time_series_with_areas(
221+
fixture,
222+
api: AnophelesSnpFrequencyAnalysis,
223+
):
224+
# Pick test parameters at random.
225+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
226+
sample_sets = random.choice(all_sample_sets)
227+
site_mask = random.choice(api.site_mask_ids + (None,))
228+
transcript = random_transcript(api=api).name
229+
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
230+
period_by = random.choice(["year", "quarter", "month"])
231+
232+
# Compute SNP frequencies.
233+
ds = api.snp_allele_frequencies_advanced(
234+
transcript=transcript,
235+
area_by=area_by,
236+
period_by=period_by,
237+
sample_sets=sample_sets,
238+
min_cohort_size=1, # Don't exclude any samples.
239+
site_mask=site_mask,
240+
)
241+
242+
# Trim things down a bit for speed.
243+
ds = ds.isel(variants=slice(0, 100))
244+
245+
# Extract cohorts into a DataFrame.
246+
cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
247+
df_cohorts = ds[cohort_vars].to_dataframe()
248+
249+
# Pick a random area and areas from valid areas.
250+
cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
251+
area = random.choice(cohorts_areas)
252+
areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))
253+
254+
# Plot with area.
255+
fig = api.plot_frequencies_time_series(ds, show=False, areas=area)
256+
257+
# Test areas plot.
258+
assert isinstance(fig, go.Figure)
259+
260+
# Plot with areas.
261+
fig = api.plot_frequencies_time_series(ds, show=False, areas=areas)
262+
263+
# Test area plot.
264+
assert isinstance(fig, go.Figure)
265+
266+
267+
@parametrize_with_cases("fixture,api", cases=".")
268+
def test_plot_frequencies_interactive_map(
269+
fixture,
270+
api: AnophelesSnpFrequencyAnalysis,
271+
):
272+
import ipywidgets # type: ignore
273+
274+
# Pick test parameters at random.
275+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
276+
sample_sets = random.choice(all_sample_sets)
277+
site_mask = random.choice(api.site_mask_ids + (None,))
278+
min_cohort_size = random.randint(0, 2)
279+
transcript = random_transcript(api=api).name
280+
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
281+
period_by = random.choice(["year", "quarter", "month"])
282+
283+
# Compute SNP frequencies.
284+
ds = api.snp_allele_frequencies_advanced(
285+
transcript=transcript,
286+
area_by=area_by,
287+
period_by=period_by,
288+
sample_sets=sample_sets,
289+
min_cohort_size=min_cohort_size,
290+
site_mask=site_mask,
291+
)
292+
293+
# Trim things down a bit for speed.
294+
ds = ds.isel(variants=slice(0, 100))
295+
296+
# Plot.
297+
fig = api.plot_frequencies_interactive_map(ds)
298+
299+
# Test.
300+
assert isinstance(fig, ipywidgets.Widget)
301+
302+
# Compute amino acid change frequencies.
303+
ds = api.aa_allele_frequencies_advanced(
304+
transcript=transcript,
305+
area_by=area_by,
306+
period_by=period_by,
307+
sample_sets=sample_sets,
308+
min_cohort_size=min_cohort_size,
309+
)
310+
311+
# Trim things down a bit for speed.
312+
ds = ds.isel(variants=slice(0, 100))
313+
314+
# Plot.
315+
fig = api.plot_frequencies_interactive_map(ds)
316+
317+
# Test.
318+
assert isinstance(fig, ipywidgets.Widget)

0 commit comments

Comments
 (0)