Skip to content

Commit 06be149

Browse files
committed
Tried to move random_transcript to conftest. WIP.
1 parent 5a57346 commit 06be149

3 files changed

Lines changed: 13 additions & 253 deletions

File tree

tests/anoph/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1013,6 +1013,14 @@ def contigs(self) -> Tuple[str, ...]:
10131013
def random_contig(self):
10141014
return choice(self.contigs)
10151015

1016+
def random_transcript_id(self):
1017+
df_transcripts = self.genome_features.query("type == 'mRNA'")
1018+
transcript_ids = [
1019+
t.split(";")[0].split("=")[1] for t in df_transcripts.loc[:, "attributes"]
1020+
]
1021+
transcript_id = choice(transcript_ids)
1022+
return transcript_id
1023+
10161024
def random_region_str(self, region_size=None):
10171025
contig = self.random_contig()
10181026
contig_size = self.contig_sizes[contig]

tests/anoph/test_frq.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,6 @@
88
from malariagen_data import ag3 as _ag3
99
from malariagen_data.anoph.snp_frq import AnophelesSnpFrequencyAnalysis
1010

11-
from .test_snp_frq import random_transcript
12-
1311

1412
@pytest.fixture
1513
def ag3_sim_api(ag3_sim_fixture):
@@ -83,7 +81,7 @@ def test_plot_frequencies_heatmap(
8381
sample_sets = random.choice(all_sample_sets)
8482
site_mask = random.choice(api.site_mask_ids + (None,))
8583
min_cohort_size = random.randint(0, 2)
86-
transcript = random_transcript(api=api).name
84+
transcript = fixture.random_transcript_id()
8785
cohorts = random.choice(
8886
["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
8987
)
@@ -128,7 +126,7 @@ def test_plot_frequencies_time_series(
128126
sample_sets = random.choice(all_sample_sets)
129127
site_mask = random.choice(api.site_mask_ids + (None,))
130128
min_cohort_size = random.randint(0, 2)
131-
transcript = random_transcript(api=api).name
129+
transcript = fixture.random_transcript_id()
132130
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
133131
period_by = random.choice(["year", "quarter", "month"])
134132

@@ -179,7 +177,7 @@ def test_plot_frequencies_time_series_with_taxa(
179177
all_sample_sets = api.sample_sets()["sample_set"].to_list()
180178
sample_sets = random.choice(all_sample_sets)
181179
site_mask = random.choice(api.site_mask_ids + (None,))
182-
transcript = random_transcript(api=api).name
180+
transcript = fixture.random_transcript_id()
183181
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
184182
period_by = random.choice(["year", "quarter", "month"])
185183

@@ -225,7 +223,7 @@ def test_plot_frequencies_time_series_with_areas(
225223
all_sample_sets = api.sample_sets()["sample_set"].to_list()
226224
sample_sets = random.choice(all_sample_sets)
227225
site_mask = random.choice(api.site_mask_ids + (None,))
228-
transcript = random_transcript(api=api).name
226+
transcript = fixture.random_transcript_id()
229227
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
230228
period_by = random.choice(["year", "quarter", "month"])
231229

@@ -276,7 +274,7 @@ def test_plot_frequencies_interactive_map(
276274
sample_sets = random.choice(all_sample_sets)
277275
site_mask = random.choice(api.site_mask_ids + (None,))
278276
min_cohort_size = random.randint(0, 2)
279-
transcript = random_transcript(api=api).name
277+
transcript = fixture.random_transcript_id()
280278
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
281279
period_by = random.choice(["year", "quarter", "month"])
282280

tests/anoph/test_snp_frq.py

Lines changed: 0 additions & 246 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
from pytest_cases import parametrize_with_cases
88
import xarray as xr
99
from numpy.testing import assert_allclose, assert_array_equal
10-
import plotly.graph_objects as go # type: ignore
1110

1211
from malariagen_data import af1 as _af1
1312
from malariagen_data import ag3 as _ag3
@@ -1429,248 +1428,3 @@ def test_allele_frequencies_advanced_with_dup_samples(
14291428
api=api,
14301429
sample_sets=sample_sets,
14311430
)
1432-
1433-
1434-
@parametrize_with_cases("fixture,api", cases=".")
1435-
def test_plot_frequencies_heatmap(
1436-
fixture,
1437-
api: AnophelesSnpFrequencyAnalysis,
1438-
):
1439-
# Pick test parameters at random.
1440-
all_sample_sets = api.sample_sets()["sample_set"].to_list()
1441-
sample_sets = random.choice(all_sample_sets)
1442-
site_mask = random.choice(api.site_mask_ids + (None,))
1443-
min_cohort_size = random.randint(0, 2)
1444-
transcript = random_transcript(api=api).name
1445-
cohorts = random.choice(
1446-
["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
1447-
)
1448-
1449-
# Set up call params.
1450-
params = dict(
1451-
transcript=transcript,
1452-
cohorts=cohorts,
1453-
min_cohort_size=min_cohort_size,
1454-
site_mask=site_mask,
1455-
sample_sets=sample_sets,
1456-
)
1457-
1458-
# Test SNP allele frequencies.
1459-
df_snp = api.snp_allele_frequencies(**params)
1460-
fig = api.plot_frequencies_heatmap(df_snp, show=False, max_len=None)
1461-
assert isinstance(fig, go.Figure)
1462-
1463-
# Test amino acid change allele frequencies.
1464-
df_aa = api.aa_allele_frequencies(**params)
1465-
fig = api.plot_frequencies_heatmap(df_aa, show=False, max_len=None)
1466-
assert isinstance(fig, go.Figure)
1467-
1468-
# Test max_len behaviour.
1469-
with pytest.raises(ValueError):
1470-
api.plot_frequencies_heatmap(df_snp, show=False, max_len=len(df_snp) - 1)
1471-
1472-
# Test index parameter - if None, should use dataframe index.
1473-
fig = api.plot_frequencies_heatmap(df_snp, show=False, index=None, max_len=None)
1474-
# Not unique.
1475-
with pytest.raises(ValueError):
1476-
api.plot_frequencies_heatmap(df_snp, show=False, index="contig", max_len=None)
1477-
1478-
1479-
@parametrize_with_cases("fixture,api", cases=".")
1480-
def test_plot_frequencies_time_series(
1481-
fixture,
1482-
api: AnophelesSnpFrequencyAnalysis,
1483-
):
1484-
# Pick test parameters at random.
1485-
all_sample_sets = api.sample_sets()["sample_set"].to_list()
1486-
sample_sets = random.choice(all_sample_sets)
1487-
site_mask = random.choice(api.site_mask_ids + (None,))
1488-
min_cohort_size = random.randint(0, 2)
1489-
transcript = random_transcript(api=api).name
1490-
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
1491-
period_by = random.choice(["year", "quarter", "month"])
1492-
1493-
# Compute SNP frequencies.
1494-
ds = api.snp_allele_frequencies_advanced(
1495-
transcript=transcript,
1496-
area_by=area_by,
1497-
period_by=period_by,
1498-
sample_sets=sample_sets,
1499-
min_cohort_size=min_cohort_size,
1500-
site_mask=site_mask,
1501-
)
1502-
1503-
# Trim things down a bit for speed.
1504-
ds = ds.isel(variants=slice(0, 100))
1505-
1506-
# Plot.
1507-
fig = api.plot_frequencies_time_series(ds, show=False)
1508-
1509-
# Test.
1510-
assert isinstance(fig, go.Figure)
1511-
1512-
# Compute amino acid change frequencies.
1513-
ds = api.aa_allele_frequencies_advanced(
1514-
transcript=transcript,
1515-
area_by=area_by,
1516-
period_by=period_by,
1517-
sample_sets=sample_sets,
1518-
min_cohort_size=min_cohort_size,
1519-
)
1520-
1521-
# Trim things down a bit for speed.
1522-
ds = ds.isel(variants=slice(0, 100))
1523-
1524-
# Plot.
1525-
fig = api.plot_frequencies_time_series(ds, show=False)
1526-
1527-
# Test.
1528-
assert isinstance(fig, go.Figure)
1529-
1530-
1531-
@parametrize_with_cases("fixture,api", cases=".")
1532-
def test_plot_frequencies_time_series_with_taxa(
1533-
fixture,
1534-
api: AnophelesSnpFrequencyAnalysis,
1535-
):
1536-
# Pick test parameters at random.
1537-
all_sample_sets = api.sample_sets()["sample_set"].to_list()
1538-
sample_sets = random.choice(all_sample_sets)
1539-
site_mask = random.choice(api.site_mask_ids + (None,))
1540-
transcript = random_transcript(api=api).name
1541-
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
1542-
period_by = random.choice(["year", "quarter", "month"])
1543-
1544-
# Pick a random taxon and taxa from valid taxa.
1545-
sample_sets_taxa = (
1546-
api.sample_metadata(sample_sets=sample_sets)["taxon"].dropna().unique().tolist()
1547-
)
1548-
taxon = random.choice(sample_sets_taxa)
1549-
taxa = random.sample(sample_sets_taxa, random.randint(1, len(sample_sets_taxa)))
1550-
1551-
# Compute SNP frequencies.
1552-
ds = api.snp_allele_frequencies_advanced(
1553-
transcript=transcript,
1554-
area_by=area_by,
1555-
period_by=period_by,
1556-
sample_sets=sample_sets,
1557-
min_cohort_size=1, # Don't exclude any samples.
1558-
site_mask=site_mask,
1559-
)
1560-
1561-
# Trim things down a bit for speed.
1562-
ds = ds.isel(variants=slice(0, 100))
1563-
1564-
# Plot with taxon.
1565-
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon)
1566-
1567-
# Test taxon plot.
1568-
assert isinstance(fig, go.Figure)
1569-
1570-
# Plot with taxa.
1571-
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxa)
1572-
1573-
# Test taxa plot.
1574-
assert isinstance(fig, go.Figure)
1575-
1576-
1577-
@parametrize_with_cases("fixture,api", cases=".")
1578-
def test_plot_frequencies_time_series_with_areas(
1579-
fixture,
1580-
api: AnophelesSnpFrequencyAnalysis,
1581-
):
1582-
# Pick test parameters at random.
1583-
all_sample_sets = api.sample_sets()["sample_set"].to_list()
1584-
sample_sets = random.choice(all_sample_sets)
1585-
site_mask = random.choice(api.site_mask_ids + (None,))
1586-
transcript = random_transcript(api=api).name
1587-
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
1588-
period_by = random.choice(["year", "quarter", "month"])
1589-
1590-
# Compute SNP frequencies.
1591-
ds = api.snp_allele_frequencies_advanced(
1592-
transcript=transcript,
1593-
area_by=area_by,
1594-
period_by=period_by,
1595-
sample_sets=sample_sets,
1596-
min_cohort_size=1, # Don't exclude any samples.
1597-
site_mask=site_mask,
1598-
)
1599-
1600-
# Trim things down a bit for speed.
1601-
ds = ds.isel(variants=slice(0, 100))
1602-
1603-
# Extract cohorts into a DataFrame.
1604-
cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
1605-
df_cohorts = ds[cohort_vars].to_dataframe()
1606-
1607-
# Pick a random area and areas from valid areas.
1608-
cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
1609-
area = random.choice(cohorts_areas)
1610-
areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))
1611-
1612-
# Plot with area.
1613-
fig = api.plot_frequencies_time_series(ds, show=False, areas=area)
1614-
1615-
# Test areas plot.
1616-
assert isinstance(fig, go.Figure)
1617-
1618-
# Plot with areas.
1619-
fig = api.plot_frequencies_time_series(ds, show=False, areas=areas)
1620-
1621-
# Test area plot.
1622-
assert isinstance(fig, go.Figure)
1623-
1624-
1625-
@parametrize_with_cases("fixture,api", cases=".")
1626-
def test_plot_frequencies_interactive_map(
1627-
fixture,
1628-
api: AnophelesSnpFrequencyAnalysis,
1629-
):
1630-
import ipywidgets # type: ignore
1631-
1632-
# Pick test parameters at random.
1633-
all_sample_sets = api.sample_sets()["sample_set"].to_list()
1634-
sample_sets = random.choice(all_sample_sets)
1635-
site_mask = random.choice(api.site_mask_ids + (None,))
1636-
min_cohort_size = random.randint(0, 2)
1637-
transcript = random_transcript(api=api).name
1638-
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
1639-
period_by = random.choice(["year", "quarter", "month"])
1640-
1641-
# Compute SNP frequencies.
1642-
ds = api.snp_allele_frequencies_advanced(
1643-
transcript=transcript,
1644-
area_by=area_by,
1645-
period_by=period_by,
1646-
sample_sets=sample_sets,
1647-
min_cohort_size=min_cohort_size,
1648-
site_mask=site_mask,
1649-
)
1650-
1651-
# Trim things down a bit for speed.
1652-
ds = ds.isel(variants=slice(0, 100))
1653-
1654-
# Plot.
1655-
fig = api.plot_frequencies_interactive_map(ds)
1656-
1657-
# Test.
1658-
assert isinstance(fig, ipywidgets.Widget)
1659-
1660-
# Compute amino acid change frequencies.
1661-
ds = api.aa_allele_frequencies_advanced(
1662-
transcript=transcript,
1663-
area_by=area_by,
1664-
period_by=period_by,
1665-
sample_sets=sample_sets,
1666-
min_cohort_size=min_cohort_size,
1667-
)
1668-
1669-
# Trim things down a bit for speed.
1670-
ds = ds.isel(variants=slice(0, 100))
1671-
1672-
# Plot.
1673-
fig = api.plot_frequencies_interactive_map(ds)
1674-
1675-
# Test.
1676-
assert isinstance(fig, ipywidgets.Widget)

0 commit comments

Comments
 (0)