11import io
2+ import json
23from itertools import cycle
34from typing import (
45 Any ,
@@ -81,6 +82,8 @@ def __init__(
8182
8283 # Initialize cache attributes.
8384 self ._cache_sample_metadata : Dict = dict ()
85+ self ._cache_cohorts : Dict = dict ()
86+ self ._cache_cohort_geometries : Dict = dict ()
8487
8588 def _metadata_paths (
8689 self ,
@@ -1496,7 +1499,11 @@ def _setup_cohort_queries(
14961499 A cohort set name. Accepted values are:
14971500 "admin1_month", "admin1_quarter", "admin1_year",
14981501 "admin2_month", "admin2_quarter", "admin2_year".
1499- """
1502+ """ ,
1503+ query = """
1504+ An optional pandas query string to filter the resulting
1505+ dataframe, e.g., "country == 'Burkina Faso'".
1506+ """ ,
15001507 ),
15011508 returns = """A dataframe of cohort data, one row per cohort. There are up to 18 columns:
15021509 `cohort_id` is the identifier of the cohort,
@@ -1523,20 +1530,98 @@ def _setup_cohort_queries(
def cohorts(
    self,
    cohort_set: base_params.cohorts,
    query: Optional[str] = None,
) -> pd.DataFrame:
    # Reject unknown cohort set names before touching cache or storage.
    known_sets = {
        f"{level}_{period}"
        for level in ("admin1", "admin2")
        for period in ("month", "quarter", "year")
    }
    if cohort_set not in known_sets:
        raise ValueError(
            f"{cohort_set!r} is not a valid cohort set. "
            f"Accepted values are: {sorted(known_sets)}."
        )

    analysis = self._cohorts_analysis

    # One cached dataframe per (analysis version, cohort set) pair, so the
    # CSV is read at most once per pair for the lifetime of this object.
    key = (analysis, cohort_set)
    try:
        table = self._cache_cohorts[key]
    except KeyError:
        release_prefix = self._major_version_path[:2]
        csv_path = (
            f"{release_prefix}_cohorts/cohorts_{analysis}/cohorts_{cohort_set}.csv"
        )

        with self.open_file(csv_path) as handle:
            table = pd.read_csv(handle, sep=",", na_values="")

        # Normalise all column names to lower case.
        table.columns = [name.lower() for name in table.columns]  # type: ignore

        self._cache_cohorts[key] = table

    # Optional row filter; the index is renumbered from zero afterwards.
    if query is not None:
        table = table.query(query).reset_index(drop=True)

    # Hand back a copy so callers can never modify the cached dataframe.
    return table.copy()
1572+
@_check_types
@doc(
    summary="""
        Read GeoJSON geometry data for a specific cohort set,
        providing boundary geometries for each cohort.
    """,
    parameters=dict(
        cohort_set="""
            A cohort set name. Accepted values are:
            "admin1_month", "admin1_quarter", "admin1_year",
            "admin2_month", "admin2_quarter", "admin2_year".
        """,
    ),
    returns="""
        A dict containing the parsed GeoJSON FeatureCollection,
        with boundary geometries for each cohort in the set.
        The returned dict is an independent copy and can be modified
        without affecting subsequent calls.
    """,
)
def cohort_geometries(
    self,
    cohort_set: base_params.cohorts,
) -> dict:
    # Function-scope import: keeps this block self-contained without
    # touching the module's import section.
    import copy

    # Reject unknown cohort set names before touching cache or storage.
    valid_cohort_sets = {
        "admin1_month",
        "admin1_quarter",
        "admin1_year",
        "admin2_month",
        "admin2_quarter",
        "admin2_year",
    }
    if cohort_set not in valid_cohort_sets:
        raise ValueError(
            f"{cohort_set!r} is not a valid cohort set. "
            f"Accepted values are: {sorted(valid_cohort_sets)}."
        )

    cohorts_analysis = self._cohorts_analysis

    # Cache to avoid repeated reads, keyed by (analysis version, cohort set).
    cache_key = (cohorts_analysis, cohort_set)
    try:
        geojson_data = self._cache_cohort_geometries[cache_key]
    except KeyError:
        major_version_path = self._major_version_path
        path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.geojson"

        with self.open_file(path) as f:
            geojson_data = json.load(f)

        self._cache_cohort_geometries[cache_key] = geojson_data

    # Return a deep copy: the parsed GeoJSON is a nested mutable structure,
    # and handing out the cached object itself would let a caller's edits
    # leak into every later call (cohorts() copies its result for the same
    # reason).
    return copy.deepcopy(geojson_data)
15401625
15411626 @_check_types
15421627 @doc (
0 commit comments