Skip to content

Commit 40c77d5

Browse files
committed
Merge upstream/master into GH1083-warn-case-mismatch-query (resolve conflicts)
2 parents c12e397 + eae786e commit 40c77d5

12 files changed

Lines changed: 547 additions & 35 deletions

File tree

malariagen_data/anoph/sample_metadata.py

Lines changed: 94 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import io
2+
import json
23
import re
34
from itertools import cycle
45
from typing import (
@@ -82,6 +83,8 @@ def __init__(
8283

8384
# Initialize cache attributes.
8485
self._cache_sample_metadata: Dict = dict()
86+
self._cache_cohorts: Dict = dict()
87+
self._cache_cohort_geometries: Dict = dict()
8588

8689
def _metadata_paths(
8790
self,
@@ -1522,7 +1525,11 @@ def _setup_cohort_queries(
15221525
A cohort set name. Accepted values are:
15231526
"admin1_month", "admin1_quarter", "admin1_year",
15241527
"admin2_month", "admin2_quarter", "admin2_year".
1525-
"""
1528+
""",
1529+
query="""
1530+
An optional pandas query string to filter the resulting
1531+
dataframe, e.g., "country == 'Burkina Faso'".
1532+
""",
15261533
),
15271534
returns="""A dataframe of cohort data, one row per cohort. There are up to 18 columns:
15281535
`cohort_id` is the identifier of the cohort,
@@ -1549,20 +1556,98 @@ def _setup_cohort_queries(
15491556
def cohorts(
15501557
self,
15511558
cohort_set: base_params.cohorts,
1559+
query: Optional[str] = None,
15521560
) -> pd.DataFrame:
1553-
major_version_path = self._major_version_path
1561+
valid_cohort_sets = {
1562+
"admin1_month",
1563+
"admin1_quarter",
1564+
"admin1_year",
1565+
"admin2_month",
1566+
"admin2_quarter",
1567+
"admin2_year",
1568+
}
1569+
if cohort_set not in valid_cohort_sets:
1570+
raise ValueError(
1571+
f"{cohort_set!r} is not a valid cohort set. "
1572+
f"Accepted values are: {sorted(valid_cohort_sets)}."
1573+
)
1574+
1575+
cohorts_analysis = self._cohorts_analysis
1576+
1577+
# Cache to avoid repeated reads.
1578+
cache_key = (cohorts_analysis, cohort_set)
1579+
try:
1580+
df_cohorts = self._cache_cohorts[cache_key]
1581+
except KeyError:
1582+
major_version_path = self._major_version_path
1583+
path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.csv"
1584+
1585+
with self.open_file(path) as f:
1586+
df_cohorts = pd.read_csv(f, sep=",", na_values="")
1587+
1588+
# Ensure all column names are lower case.
1589+
df_cohorts.columns = [c.lower() for c in df_cohorts.columns] # type: ignore
1590+
1591+
self._cache_cohorts[cache_key] = df_cohorts
1592+
1593+
if query is not None:
1594+
df_cohorts = df_cohorts.query(query)
1595+
df_cohorts = df_cohorts.reset_index(drop=True)
1596+
1597+
return df_cohorts.copy()
1598+
1599+
@_check_types
1600+
@doc(
1601+
summary="""
1602+
Read GeoJSON geometry data for a specific cohort set,
1603+
providing boundary geometries for each cohort.
1604+
""",
1605+
parameters=dict(
1606+
cohort_set="""
1607+
A cohort set name. Accepted values are:
1608+
"admin1_month", "admin1_quarter", "admin1_year",
1609+
"admin2_month", "admin2_quarter", "admin2_year".
1610+
""",
1611+
),
1612+
returns="""
1613+
A dict containing the parsed GeoJSON FeatureCollection,
1614+
with boundary geometries for each cohort in the set.
1615+
""",
1616+
)
1617+
def cohort_geometries(
1618+
self,
1619+
cohort_set: base_params.cohorts,
1620+
) -> dict:
1621+
valid_cohort_sets = {
1622+
"admin1_month",
1623+
"admin1_quarter",
1624+
"admin1_year",
1625+
"admin2_month",
1626+
"admin2_quarter",
1627+
"admin2_year",
1628+
}
1629+
if cohort_set not in valid_cohort_sets:
1630+
raise ValueError(
1631+
f"{cohort_set!r} is not a valid cohort set. "
1632+
f"Accepted values are: {sorted(valid_cohort_sets)}."
1633+
)
1634+
15541635
cohorts_analysis = self._cohorts_analysis
15551636

1556-
path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.csv"
1637+
# Cache to avoid repeated reads.
1638+
cache_key = (cohorts_analysis, cohort_set)
1639+
try:
1640+
geojson_data = self._cache_cohort_geometries[cache_key]
1641+
except KeyError:
1642+
major_version_path = self._major_version_path
1643+
path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.geojson"
15571644

1558-
# Read the manifest into a pandas dataframe.
1559-
with self.open_file(path) as f:
1560-
df_cohorts = pd.read_csv(f, sep=",", na_values="")
1645+
with self.open_file(path) as f:
1646+
geojson_data = json.load(f)
15611647

1562-
# Ensure all column names are lower case.
1563-
df_cohorts.columns = [c.lower() for c in df_cohorts.columns] # type: ignore
1648+
self._cache_cohort_geometries[cache_key] = geojson_data
15641649

1565-
return df_cohorts
1650+
return geojson_data
15661651

15671652
@_check_types
15681653
@doc(

malariagen_data/util.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -855,9 +855,7 @@ def _value_error(
855855
value,
856856
expectation,
857857
):
858-
message = (
859-
f"Bad value for parameter {name}; expected {expectation}, " f"found {value!r}"
860-
)
858+
message = f"Bad value for parameter {name}; expected {expectation}, found {value!r}"
861859
raise ValueError(message)
862860

863861

@@ -935,6 +933,7 @@ def info(self, msg):
935933
self.flush()
936934

937935
def set_level(self, level):
936+
self._logger.setLevel(level)
938937
if self._handler is not None:
939938
self._handler.setLevel(level)
940939

notebooks/cohort_geometries.ipynb

Lines changed: 115 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,115 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Cohort Geometries\n",
8+
"\n",
9+
"Demonstrates the `cohort_geometries()` method for accessing GeoJSON boundary data."
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import malariagen_data"
19+
]
20+
},
21+
{
22+
"cell_type": "markdown",
23+
"metadata": {},
24+
"source": [
25+
"## Set up the Ag3 data resource"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"ag3 = malariagen_data.Ag3(\n",
35+
" \"simplecache::gs://vo_agam_release_master_us_central1\",\n",
36+
" simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
37+
")\n",
38+
"ag3"
39+
]
40+
},
41+
{
42+
"cell_type": "markdown",
43+
"metadata": {},
44+
"source": [
45+
"## Access cohort geometries"
46+
]
47+
},
48+
{
49+
"cell_type": "code",
50+
"execution_count": null,
51+
"metadata": {},
52+
"outputs": [],
53+
"source": [
54+
"geojson = ag3.cohort_geometries(cohort_set=\"admin1_year\")"
55+
]
56+
},
57+
{
58+
"cell_type": "markdown",
59+
"metadata": {},
60+
"source": [
61+
"## Inspect the returned data"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"metadata": {},
68+
"outputs": [],
69+
"source": [
70+
"type(geojson)"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": null,
76+
"metadata": {},
77+
"outputs": [],
78+
"source": [
79+
"geojson.keys()"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"metadata": {},
86+
"outputs": [],
87+
"source": [
88+
"len(geojson[\"features\"])"
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": null,
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"for f in geojson[\"features\"][:3]:\n",
98+
" print(f[\"properties\"])"
99+
]
100+
}
101+
],
102+
"metadata": {
103+
"kernelspec": {
104+
"display_name": "Python 3 (ipykernel)",
105+
"language": "python",
106+
"name": "python3"
107+
},
108+
"language_info": {
109+
"name": "python",
110+
"version": "3.10.0"
111+
}
112+
},
113+
"nbformat": 4,
114+
"nbformat_minor": 4
115+
}

notebooks/plot_haplotypes_frequencies.ipynb

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@
102102
" sample_sets=[\"1232-VO-KE-OCHOMO-VMF00044\"],\n",
103103
" min_cohort_size=10,\n",
104104
")\n",
105-
"ag3.plot_frequencies_time_series(hap_xr)"
105+
"af1.plot_frequencies_time_series(hap_xr)"
106106
]
107107
},
108108
{

tests/anoph/conftest.py

Lines changed: 54 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -334,7 +334,7 @@ def simulate_exons(
334334
# keep things simple for now.
335335
if strand == "-":
336336
# Take exons in reverse order.
337-
exons == exons[::-1]
337+
exons = exons[::-1]
338338
for exon_ix, exon in enumerate(exons):
339339
first_exon = exon_ix == 0
340340
last_exon = exon_ix == len(exons) - 1
@@ -646,8 +646,8 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes, rng):
646646
# - sample_is_high_variance [1D array] [bool] [True or False for n_samples]
647647
# - samples [1D array] [str]
648648

649-
# Get a random probability for a sample being high variance, between 0 and 1.
650-
p_variance = rng.random()
649+
# Keep high variance sample prevalence stable for deterministic tests.
650+
p_variance = 0.1
651651

652652
# Open a zarr at the specified path.
653653
root = zarr.open(zarr_path, mode="w")
@@ -862,8 +862,8 @@ def simulate_cnv_discordant_read_calls(
862862
# - sample_is_high_variance [1D array] [bool] [True or False for n_samples]
863863
# - samples [1D array] [str for n_samples]
864864

865-
# Get a random probability for a sample being high variance, between 0 and 1.
866-
p_variance = rng.random()
865+
# Keep high variance sample prevalence stable for deterministic tests.
866+
p_variance = 0.1
867867

868868
# Get a random probability for choosing allele 1, between 0 and 1.
869869
p_allele = rng.random()
@@ -1408,23 +1408,55 @@ def write_metadata(
14081408
df_coh_ds.to_csv(dst_path, index=False)
14091409

14101410
# Create cohorts data by sampling from some real files.
1411-
src_path = (
1412-
self.fixture_dir
1413-
/ "vo_agam_release_master_us_central1"
1414-
/ "v3_cohorts"
1415-
/ "cohorts_20230516"
1416-
/ "cohorts_admin1_month.csv"
1417-
)
1418-
dst_path = (
1419-
self.bucket_path
1420-
/ "v3_cohorts"
1421-
/ "cohorts_20230516"
1422-
/ "cohorts_admin1_month.csv"
1423-
)
1424-
dst_path.parent.mkdir(parents=True, exist_ok=True)
1425-
with open(src_path, mode="r") as src, open(dst_path, mode="w") as dst:
1426-
for line in src.readlines()[:5]:
1427-
print(line, file=dst)
1411+
cohort_files = [
1412+
"cohorts_admin1_month.csv",
1413+
"cohorts_admin1_year.csv",
1414+
"cohorts_admin2_month.csv",
1415+
]
1416+
for cohort_file in cohort_files:
1417+
src_path = (
1418+
self.fixture_dir
1419+
/ "vo_agam_release_master_us_central1"
1420+
/ "v3_cohorts"
1421+
/ "cohorts_20230516"
1422+
/ cohort_file
1423+
)
1424+
if src_path.exists():
1425+
dst_path = (
1426+
self.bucket_path
1427+
/ "v3_cohorts"
1428+
/ "cohorts_20230516"
1429+
/ cohort_file
1430+
)
1431+
dst_path.parent.mkdir(parents=True, exist_ok=True)
1432+
with open(src_path, mode="r") as src, open(
1433+
dst_path, mode="w"
1434+
) as dst:
1435+
for line in src.readlines()[:5]:
1436+
print(line, file=dst)
1437+
1438+
# Copy cohort GeoJSON fixtures.
1439+
geojson_files = [
1440+
"cohorts_admin1_month.geojson",
1441+
"cohorts_admin1_year.geojson",
1442+
]
1443+
for geojson_file in geojson_files:
1444+
src_path = (
1445+
self.fixture_dir
1446+
/ "vo_agam_release_master_us_central1"
1447+
/ "v3_cohorts"
1448+
/ "cohorts_20230516"
1449+
/ geojson_file
1450+
)
1451+
if src_path.exists():
1452+
dst_path = (
1453+
self.bucket_path
1454+
/ "v3_cohorts"
1455+
/ "cohorts_20230516"
1456+
/ geojson_file
1457+
)
1458+
dst_path.parent.mkdir(parents=True, exist_ok=True)
1459+
shutil.copy2(src_path, dst_path)
14281460

14291461
# Create data catalog by sampling from some real metadata files.
14301462
src_path = (

0 commit comments

Comments (0)