11import io
2+ import json
23from itertools import cycle
34from typing import (
45 Any ,
@@ -81,6 +82,8 @@ def __init__(
8182
8283 # Initialize cache attributes.
8384 self ._cache_sample_metadata : Dict = dict ()
85+ self ._cache_cohorts : Dict = dict ()
86+ self ._cache_cohort_geometries : Dict = dict ()
8487
8588 def _metadata_paths (
8689 self ,
@@ -1485,7 +1488,11 @@ def _setup_cohort_queries(
14851488 A cohort set name. Accepted values are:
14861489 "admin1_month", "admin1_quarter", "admin1_year",
14871490 "admin2_month", "admin2_quarter", "admin2_year".
1488- """
1491+ """ ,
1492+ query = """
1493+ An optional pandas query string to filter the resulting
1494+ dataframe, e.g., "country == 'Burkina Faso'".
1495+ """ ,
14891496 ),
14901497 returns = """A dataframe of cohort data, one row per cohort. There are up to 18 columns:
14911498 `cohort_id` is the identifier of the cohort,
@@ -1512,20 +1519,98 @@ def _setup_cohort_queries(
15121519 def cohorts (
15131520 self ,
15141521 cohort_set : base_params .cohorts ,
1522+ query : Optional [str ] = None ,
15151523 ) -> pd .DataFrame :
1516- major_version_path = self ._major_version_path
1524+ valid_cohort_sets = {
1525+ "admin1_month" ,
1526+ "admin1_quarter" ,
1527+ "admin1_year" ,
1528+ "admin2_month" ,
1529+ "admin2_quarter" ,
1530+ "admin2_year" ,
1531+ }
1532+ if cohort_set not in valid_cohort_sets :
1533+ raise ValueError (
1534+ f"{ cohort_set !r} is not a valid cohort set. "
1535+ f"Accepted values are: { sorted (valid_cohort_sets )} ."
1536+ )
1537+
1538+ cohorts_analysis = self ._cohorts_analysis
1539+
1540+ # Cache to avoid repeated reads.
1541+ cache_key = (cohorts_analysis , cohort_set )
1542+ try :
1543+ df_cohorts = self ._cache_cohorts [cache_key ]
1544+ except KeyError :
1545+ major_version_path = self ._major_version_path
1546+ path = f"{ major_version_path [:2 ]} _cohorts/cohorts_{ cohorts_analysis } /cohorts_{ cohort_set } .csv"
1547+
1548+ with self .open_file (path ) as f :
1549+ df_cohorts = pd .read_csv (f , sep = "," , na_values = "" )
1550+
1551+ # Ensure all column names are lower case.
1552+ df_cohorts .columns = [c .lower () for c in df_cohorts .columns ] # type: ignore
1553+
1554+ self ._cache_cohorts [cache_key ] = df_cohorts
1555+
1556+ if query is not None :
1557+ df_cohorts = df_cohorts .query (query )
1558+ df_cohorts = df_cohorts .reset_index (drop = True )
1559+
1560+ return df_cohorts .copy ()
1561+
1562+ @_check_types
1563+ @doc (
1564+ summary = """
1565+ Read GeoJSON geometry data for a specific cohort set,
1566+ providing boundary geometries for each cohort.
1567+ """ ,
1568+ parameters = dict (
1569+ cohort_set = """
1570+ A cohort set name. Accepted values are:
1571+ "admin1_month", "admin1_quarter", "admin1_year",
1572+ "admin2_month", "admin2_quarter", "admin2_year".
1573+ """ ,
1574+ ),
1575+ returns = """
1576+ A dict containing the parsed GeoJSON FeatureCollection,
1577+ with boundary geometries for each cohort in the set.
1578+ """ ,
1579+ )
1580+ def cohort_geometries (
1581+ self ,
1582+ cohort_set : base_params .cohorts ,
1583+ ) -> dict :
1584+ valid_cohort_sets = {
1585+ "admin1_month" ,
1586+ "admin1_quarter" ,
1587+ "admin1_year" ,
1588+ "admin2_month" ,
1589+ "admin2_quarter" ,
1590+ "admin2_year" ,
1591+ }
1592+ if cohort_set not in valid_cohort_sets :
1593+ raise ValueError (
1594+ f"{ cohort_set !r} is not a valid cohort set. "
1595+ f"Accepted values are: { sorted (valid_cohort_sets )} ."
1596+ )
1597+
15171598 cohorts_analysis = self ._cohorts_analysis
15181599
1519- path = f"{ major_version_path [:2 ]} _cohorts/cohorts_{ cohorts_analysis } /cohorts_{ cohort_set } .csv"
1600+ # Cache to avoid repeated reads.
1601+ cache_key = (cohorts_analysis , cohort_set )
1602+ try :
1603+ geojson_data = self ._cache_cohort_geometries [cache_key ]
1604+ except KeyError :
1605+ major_version_path = self ._major_version_path
1606+ path = f"{ major_version_path [:2 ]} _cohorts/cohorts_{ cohorts_analysis } /cohorts_{ cohort_set } .geojson"
15201607
1521- # Read the manifest into a pandas dataframe.
1522- with self .open_file (path ) as f :
1523- df_cohorts = pd .read_csv (f , sep = "," , na_values = "" )
1608+ with self .open_file (path ) as f :
1609+ geojson_data = json .load (f )
15241610
1525- # Ensure all column names are lower case.
1526- df_cohorts .columns = [c .lower () for c in df_cohorts .columns ] # type: ignore
1611+ self ._cache_cohort_geometries [cache_key ] = geojson_data
15271612
1528- return df_cohorts
1613+ return geojson_data
15291614
15301615 @_check_types
15311616 @doc (
0 commit comments