Skip to content

Commit fe4f512

Browse files
authored
Merge pull request #1053 from adilraza99/GH435-cohort-geometries
Add cohort_geometries() to access cohort GeoJSON metadata
2 parents 5b3315e + a68d665 commit fe4f512

6 files changed

Lines changed: 296 additions & 0 deletions

File tree

malariagen_data/anoph/sample_metadata.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import json
23
from itertools import cycle
34
from typing import (
45
Any,
@@ -81,6 +82,7 @@ def __init__(
8182

8283
# Initialize cache attributes.
8384
self._cache_sample_metadata: Dict = dict()
85+
self._cache_cohort_geometries: Dict = dict()
8486

8587
def _metadata_paths(
8688
self,
@@ -1527,6 +1529,59 @@ def cohorts(
15271529

15281530
return df_cohorts
15291531

1532+
@_check_types
1533+
@doc(
1534+
summary="""
1535+
Read GeoJSON geometry data for a specific cohort set,
1536+
providing boundary geometries for each cohort.
1537+
""",
1538+
parameters=dict(
1539+
cohort_set="""
1540+
A cohort set name. Accepted values are:
1541+
"admin1_month", "admin1_quarter", "admin1_year",
1542+
"admin2_month", "admin2_quarter", "admin2_year".
1543+
""",
1544+
),
1545+
returns="""
1546+
A dict containing the parsed GeoJSON FeatureCollection,
1547+
with boundary geometries for each cohort in the set.
1548+
""",
1549+
)
1550+
def cohort_geometries(
1551+
self,
1552+
cohort_set: base_params.cohorts,
1553+
) -> dict:
1554+
valid_cohort_sets = {
1555+
"admin1_month",
1556+
"admin1_quarter",
1557+
"admin1_year",
1558+
"admin2_month",
1559+
"admin2_quarter",
1560+
"admin2_year",
1561+
}
1562+
if cohort_set not in valid_cohort_sets:
1563+
raise ValueError(
1564+
f"{cohort_set!r} is not a valid cohort set. "
1565+
f"Accepted values are: {sorted(valid_cohort_sets)}."
1566+
)
1567+
1568+
cohorts_analysis = self._cohorts_analysis
1569+
1570+
# Cache to avoid repeated reads.
1571+
cache_key = (cohorts_analysis, cohort_set)
1572+
try:
1573+
geojson_data = self._cache_cohort_geometries[cache_key]
1574+
except KeyError:
1575+
major_version_path = self._major_version_path
1576+
path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.geojson"
1577+
1578+
with self.open_file(path) as f:
1579+
geojson_data = json.load(f)
1580+
1581+
self._cache_cohort_geometries[cache_key] = geojson_data
1582+
1583+
return geojson_data
1584+
15301585
@_check_types
15311586
@doc(
15321587
summary="""

notebooks/cohort_geometries.ipynb

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Cohort Geometries\n",
8+
"\n",
9+
"Demonstrates the `cohort_geometries()` method for accessing GeoJSON boundary data."
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import malariagen_data"
19+
]
20+
},
21+
{
22+
"cell_type": "markdown",
23+
"metadata": {},
24+
"source": [
25+
"## Set up the Ag3 data resource"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"ag3 = malariagen_data.Ag3(\n",
35+
" \"simplecache::gs://vo_agam_release_master_us_central1\",\n",
36+
" simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
37+
")\n",
38+
"ag3"
39+
]
40+
},
41+
{
42+
"cell_type": "markdown",
43+
"metadata": {},
44+
"source": [
45+
"## Access cohort geometries"
46+
]
47+
},
48+
{
49+
"cell_type": "code",
50+
"execution_count": null,
51+
"metadata": {},
52+
"outputs": [],
53+
"source": [
54+
"geojson = ag3.cohort_geometries(cohort_set=\"admin1_year\")"
55+
]
56+
},
57+
{
58+
"cell_type": "markdown",
59+
"metadata": {},
60+
"source": [
61+
"## Inspect the returned data"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"metadata": {},
68+
"outputs": [],
69+
"source": [
70+
"type(geojson)"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": null,
76+
"metadata": {},
77+
"outputs": [],
78+
"source": [
79+
"geojson.keys()"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"metadata": {},
86+
"outputs": [],
87+
"source": [
88+
"len(geojson[\"features\"])"
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": null,
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"for f in geojson[\"features\"][:3]:\n",
98+
" print(f[\"properties\"])"
99+
]
100+
}
101+
],
102+
"metadata": {
103+
"kernelspec": {
104+
"display_name": "Python 3 (ipykernel)",
105+
"language": "python",
106+
"name": "python3"
107+
},
108+
"language_info": {
109+
"name": "python",
110+
"version": "3.10.0"
111+
}
112+
},
113+
"nbformat": 4,
114+
"nbformat_minor": 4
115+
}

tests/anoph/conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,29 @@ def write_metadata(
14261426
for line in src.readlines()[:5]:
14271427
print(line, file=dst)
14281428

1429+
# Copy cohort GeoJSON fixtures.
1430+
geojson_files = [
1431+
"cohorts_admin1_month.geojson",
1432+
"cohorts_admin1_year.geojson",
1433+
]
1434+
for geojson_file in geojson_files:
1435+
src_path = (
1436+
self.fixture_dir
1437+
/ "vo_agam_release_master_us_central1"
1438+
/ "v3_cohorts"
1439+
/ "cohorts_20230516"
1440+
/ geojson_file
1441+
)
1442+
if src_path.exists():
1443+
dst_path = (
1444+
self.bucket_path
1445+
/ "v3_cohorts"
1446+
/ "cohorts_20230516"
1447+
/ geojson_file
1448+
)
1449+
dst_path.parent.mkdir(parents=True, exist_ok=True)
1450+
shutil.copy2(src_path, dst_path)
1451+
14291452
# Create data catalog by sampling from some real metadata files.
14301453
src_path = (
14311454
self.fixture_dir
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
{
2+
"type": "FeatureCollection",
3+
"features": [
4+
{
5+
"type": "Feature",
6+
"properties": {
7+
"cohort_id": "BF-01_arab_2008_04",
8+
"admin1_name": "Boucle du Mouhoun",
9+
"admin1_iso": "BF-01"
10+
},
11+
"geometry": {
12+
"type": "Polygon",
13+
"coordinates": [[[-4.5, 12.0], [-3.5, 12.0], [-3.5, 13.0], [-4.5, 13.0], [-4.5, 12.0]]]
14+
}
15+
},
16+
{
17+
"type": "Feature",
18+
"properties": {
19+
"cohort_id": "BF-02_colu_2011_07",
20+
"admin1_name": "Cascades",
21+
"admin1_iso": "BF-02"
22+
},
23+
"geometry": {
24+
"type": "Polygon",
25+
"coordinates": [[[-5.0, 10.0], [-4.0, 10.0], [-4.0, 11.0], [-5.0, 11.0], [-5.0, 10.0]]]
26+
}
27+
}
28+
]
29+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
{
2+
"type": "FeatureCollection",
3+
"features": [
4+
{
5+
"type": "Feature",
6+
"properties": {
7+
"cohort_id": "AO-LUA_colu_2009",
8+
"admin1_name": "Luanda",
9+
"admin1_iso": "AO-LUA"
10+
},
11+
"geometry": {
12+
"type": "Polygon",
13+
"coordinates": [[[13.0, -10.0], [14.0, -10.0], [14.0, -9.0], [13.0, -9.0], [13.0, -10.0]]]
14+
}
15+
},
16+
{
17+
"type": "Feature",
18+
"properties": {
19+
"cohort_id": "BF-01_arab_2008",
20+
"admin1_name": "Boucle du Mouhoun",
21+
"admin1_iso": "BF-01"
22+
},
23+
"geometry": {
24+
"type": "Polygon",
25+
"coordinates": [[[-4.5, 12.0], [-3.5, 12.0], [-3.5, 13.0], [-4.5, 13.0], [-4.5, 12.0]]]
26+
}
27+
}
28+
]
29+
}

tests/anoph/test_sample_metadata.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,3 +1465,48 @@ def test_cohort_data(fixture, api):
14651465
df_cohorts = api.cohorts(cohort_name)
14661466
# Check output.
14671467
validate_cohort_data(df_cohorts, cohort_data_expected_columns())
1468+
1469+
1470+
# ------------------------------------------------------------------
1471+
# Tests for cohort_geometries()
1472+
# ------------------------------------------------------------------
1473+
1474+
1475+
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
1476+
def test_cohort_geometries(fixture, api):
1477+
"""Test that GeoJSON geometry can be loaded for a valid cohort set."""
1478+
geojson = api.cohort_geometries("admin1_month")
1479+
assert isinstance(geojson, dict)
1480+
assert geojson["type"] == "FeatureCollection"
1481+
assert "features" in geojson
1482+
assert len(geojson["features"]) > 0
1483+
for feature in geojson["features"]:
1484+
assert feature["type"] == "Feature"
1485+
assert "geometry" in feature
1486+
assert "properties" in feature
1487+
assert "coordinates" in feature["geometry"]
1488+
1489+
1490+
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
1491+
def test_cohort_geometries_admin1_year(fixture, api):
1492+
"""Test that GeoJSON geometry can be loaded for admin1_year."""
1493+
geojson = api.cohort_geometries("admin1_year")
1494+
assert isinstance(geojson, dict)
1495+
assert geojson["type"] == "FeatureCollection"
1496+
assert len(geojson["features"]) > 0
1497+
1498+
1499+
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
1500+
def test_cohort_geometries_invalid_cohort_set(fixture, api):
1501+
"""Test that an invalid cohort_set raises ValueError."""
1502+
with suppress_type_checks():
1503+
with pytest.raises(ValueError, match="not a valid cohort set"):
1504+
api.cohort_geometries("invalid_set")
1505+
1506+
1507+
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
1508+
def test_cohort_geometries_cached(fixture, api):
1509+
"""Test that the second call returns the same cached object."""
1510+
g1 = api.cohort_geometries("admin1_month")
1511+
g2 = api.cohort_geometries("admin1_month")
1512+
assert g1 is g2

0 commit comments

Comments
 (0)