Skip to content

Commit 53b3a29

Browse files
committed
Improve robustness of cohorts() metadata access
- improve validation and handling of cohort metadata - ensure consistent behaviour when cohort data is missing - maintain compatibility with cohort_geometries changes introduced in #1053 - add tests covering edge cases and error handling
1 parent fe4f512 commit 53b3a29

5 files changed

Lines changed: 148 additions & 26 deletions

File tree

malariagen_data/anoph/sample_metadata.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def __init__(
8282

8383
# Initialize cache attributes.
8484
self._cache_sample_metadata: Dict = dict()
85+
self._cache_cohorts: Dict = dict()
8586
self._cache_cohort_geometries: Dict = dict()
8687

8788
def _metadata_paths(
@@ -1487,7 +1488,11 @@ def _setup_cohort_queries(
14871488
A cohort set name. Accepted values are:
14881489
"admin1_month", "admin1_quarter", "admin1_year",
14891490
"admin2_month", "admin2_quarter", "admin2_year".
1490-
"""
1491+
""",
1492+
query="""
1493+
An optional pandas query string to filter the resulting
1494+
dataframe, e.g., "country == 'Burkina Faso'".
1495+
""",
14911496
),
14921497
returns="""A dataframe of cohort data, one row per cohort. There are up to 18 columns:
14931498
`cohort_id` is the identifier of the cohort,
@@ -1514,20 +1519,45 @@ def _setup_cohort_queries(
15141519
def cohorts(
    self,
    cohort_set: base_params.cohorts,
    query: Optional[str] = None,
) -> pd.DataFrame:
    """Return a dataframe of cohort metadata for the requested cohort set.

    The underlying CSV is read at most once per (analysis, cohort set)
    pair and cached; an optional pandas query string filters the rows
    returned. A defensive copy is always returned so that callers can
    never mutate the cached dataframe.
    """
    # Reject unknown cohort set names up front with a helpful message.
    valid_cohort_sets = {
        "admin1_month",
        "admin1_quarter",
        "admin1_year",
        "admin2_month",
        "admin2_quarter",
        "admin2_year",
    }
    if cohort_set not in valid_cohort_sets:
        raise ValueError(
            f"{cohort_set!r} is not a valid cohort set. "
            f"Accepted values are: {sorted(valid_cohort_sets)}."
        )

    cohorts_analysis = self._cohorts_analysis
    cache_key = (cohorts_analysis, cohort_set)

    # Serve from the cache where possible; otherwise read the CSV once
    # and cache the normalised result.
    df_cohorts = self._cache_cohorts.get(cache_key)
    if df_cohorts is None:
        major_version_path = self._major_version_path
        path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.csv"

        with self.open_file(path) as f:
            df_cohorts = pd.read_csv(f, sep=",", na_values="")

        # Ensure all column names are lower case before caching.
        df_cohorts.columns = [c.lower() for c in df_cohorts.columns]  # type: ignore

        self._cache_cohorts[cache_key] = df_cohorts

    if query is not None:
        df_cohorts = df_cohorts.query(query).reset_index(drop=True)

    # Copy so that mutations by the caller cannot corrupt the cache.
    return df_cohorts.copy()
15311561

15321562
@_check_types
15331563
@doc(

tests/anoph/conftest.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,23 +1408,32 @@ def write_metadata(
14081408
df_coh_ds.to_csv(dst_path, index=False)
14091409

14101410
# Create cohorts data by sampling from some real files.
1411-
src_path = (
1412-
self.fixture_dir
1413-
/ "vo_agam_release_master_us_central1"
1414-
/ "v3_cohorts"
1415-
/ "cohorts_20230516"
1416-
/ "cohorts_admin1_month.csv"
1417-
)
1418-
dst_path = (
1419-
self.bucket_path
1420-
/ "v3_cohorts"
1421-
/ "cohorts_20230516"
1422-
/ "cohorts_admin1_month.csv"
1423-
)
1424-
dst_path.parent.mkdir(parents=True, exist_ok=True)
1425-
with open(src_path, mode="r") as src, open(dst_path, mode="w") as dst:
1426-
for line in src.readlines()[:5]:
1427-
print(line, file=dst)
1411+
cohort_files = [
1412+
"cohorts_admin1_month.csv",
1413+
"cohorts_admin1_year.csv",
1414+
"cohorts_admin2_month.csv",
1415+
]
1416+
for cohort_file in cohort_files:
1417+
src_path = (
1418+
self.fixture_dir
1419+
/ "vo_agam_release_master_us_central1"
1420+
/ "v3_cohorts"
1421+
/ "cohorts_20230516"
1422+
/ cohort_file
1423+
)
1424+
if src_path.exists():
1425+
dst_path = (
1426+
self.bucket_path
1427+
/ "v3_cohorts"
1428+
/ "cohorts_20230516"
1429+
/ cohort_file
1430+
)
1431+
dst_path.parent.mkdir(parents=True, exist_ok=True)
1432+
with open(src_path, mode="r") as src, open(
1433+
dst_path, mode="w"
1434+
) as dst:
1435+
for line in src.readlines()[:5]:
1436+
print(line, file=dst)
14281437

14291438
# Copy cohort GeoJSON fixtures.
14301439
geojson_files = [
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
cohort_id,cohort_size,country,country_alpha2,country_alpha3,taxon,year,admin1_name,admin1_iso,admin1_geoboundaries_shape_id,admin1_representative_longitude,admin1_representative_latitude
2+
AO-LUA_colu_2009,81,Angola,AO,AGO,coluzzii,2009,Luanda,AO-LUA,26408823B49174064004395,13.679075010193182,-9.592222213499952
3+
BF-01_arab_2008,1,Burkina Faso,BF,BFA,arabiensis,2008,Boucle du Mouhoun,BF-01,92566538B98190668782446,-3.592255305233366,12.479899304500035
4+
BF-01_colu_2008,4,Burkina Faso,BF,BFA,coluzzii,2008,Boucle du Mouhoun,BF-01,92566538B98190668782446,-3.592255305233366,12.479899304500035
5+
BF-02_colu_2011,18,Burkina Faso,BF,BFA,coluzzii,2011,Cascades,BF-02,92566538B44525923588019,-4.482809923408134,10.30846
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
cohort_id,cohort_size,country,country_alpha2,country_alpha3,taxon,year,quarter,month,admin1_name,admin1_iso,admin1_geoboundaries_shape_id,admin1_representative_longitude,admin1_representative_latitude,admin2_name,admin2_iso,admin2_geoboundaries_shape_id,admin2_representative_longitude,admin2_representative_latitude
2+
AO-LUA-LUA_colu_2009_04,81,Angola,AO,AGO,coluzzii,2009,2,4,Luanda,AO-LUA,26408823B49174064004395,13.679075010193182,-9.592222213499952,Luanda,AO-LUA-LUA,26408823B49174064004396,13.679075010193182,-9.592222213499952
3+
BF-01-BAN_arab_2008_11,1,Burkina Faso,BF,BFA,arabiensis,2008,4,11,Boucle du Mouhoun,BF-01,92566538B98190668782446,-3.592255305233366,12.479899304500035,Banwa,BF-01-BAN,92566538B98190668782447,-3.592255305233366,12.479899304500035

tests/anoph/test_sample_metadata.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,6 +1446,47 @@ def cohort_data_expected_columns():
14461446
}
14471447

14481448

1449+
def cohort_data_admin1_year_expected_columns():
    """Expected column name -> dtype kind for an admin1_year cohorts dataframe.

    Kinds follow numpy dtype kind codes: "O" object, "i" integer, "f" float.
    The insertion order of the pairs below is the expected column order.
    """
    column_kinds = [
        ("cohort_id", "O"),
        ("cohort_size", "i"),
        ("country", "O"),
        ("country_alpha2", "O"),
        ("country_alpha3", "O"),
        ("taxon", "O"),
        ("year", "i"),
        ("admin1_name", "O"),
        ("admin1_iso", "O"),
        ("admin1_geoboundaries_shape_id", "O"),
        ("admin1_representative_longitude", "f"),
        ("admin1_representative_latitude", "f"),
    ]
    return dict(column_kinds)
1464+
1465+
1466+
def cohort_data_admin2_month_expected_columns():
    """Expected column name -> dtype kind for an admin2_month cohorts dataframe.

    Kinds follow numpy dtype kind codes: "O" object, "i" integer, "f" float.
    The insertion order of the pairs below is the expected column order.
    """
    column_kinds = [
        ("cohort_id", "O"),
        ("cohort_size", "i"),
        ("country", "O"),
        ("country_alpha2", "O"),
        ("country_alpha3", "O"),
        ("taxon", "O"),
        ("year", "i"),
        ("quarter", "i"),
        ("month", "i"),
        ("admin1_name", "O"),
        ("admin1_iso", "O"),
        ("admin1_geoboundaries_shape_id", "O"),
        ("admin1_representative_longitude", "f"),
        ("admin1_representative_latitude", "f"),
        ("admin2_name", "O"),
        ("admin2_iso", "O"),
        ("admin2_geoboundaries_shape_id", "O"),
        ("admin2_representative_longitude", "f"),
        ("admin2_representative_latitude", "f"),
    ]
    return dict(column_kinds)
1488+
1489+
14491490
def validate_cohort_data(df, expected_columns):
14501491
# Check column names.
14511492
# Note: insertion order in dictionary keys is guaranteed since Python 3.7
@@ -1467,6 +1508,40 @@ def test_cohort_data(fixture, api):
14671508
validate_cohort_data(df_cohorts, cohort_data_expected_columns())
14681509

14691510

1511+
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
def test_cohort_data_admin1_year(fixture, api):
    # Admin1-by-year cohorts should expose the admin1-level column set.
    df = api.cohorts("admin1_year")
    validate_cohort_data(df, cohort_data_admin1_year_expected_columns())
1515+
1516+
1517+
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
def test_cohort_data_admin2_month(fixture, api):
    # Admin2-by-month cohorts should expose the full admin2-level column set.
    df = api.cohorts("admin2_month")
    validate_cohort_data(df, cohort_data_admin2_month_expected_columns())
1521+
1522+
1523+
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
def test_cohort_data_invalid_cohort_set(fixture, api):
    # An unknown cohort set name should be rejected with a clear ValueError.
    with pytest.raises(ValueError, match="is not a valid cohort set"):
        api.cohorts("invalid_name")
1527+
1528+
1529+
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
def test_cohort_data_with_query(fixture, api):
    # Filtering with a pandas query string should yield a non-empty
    # strict subset where every row matches the query.
    df_unfiltered = api.cohorts("admin1_month")
    df_bf = api.cohorts("admin1_month", query="country == 'Burkina Faso'")
    assert 0 < len(df_bf) < len(df_unfiltered)
    assert (df_bf["country"] == "Burkina Faso").all()
1536+
1537+
1538+
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
def test_cohort_data_cached(fixture, api):
    # Repeated calls with the same arguments should serve identical
    # data from the cache.
    first = api.cohorts("admin1_month")
    second = api.cohorts("admin1_month")
    assert_frame_equal(first, second)
1543+
1544+
14701545
# ------------------------------------------------------------------
14711546
# Tests for cohort_geometries()
14721547
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)