11import io
2+ import json
23import re
34from itertools import cycle
45from typing import (
@@ -82,6 +83,8 @@ def __init__(
8283
8384 # Initialize cache attributes.
8485 self ._cache_sample_metadata : Dict = dict ()
86+ self ._cache_cohorts : Dict = dict ()
87+ self ._cache_cohort_geometries : Dict = dict ()
8588
8689 def _metadata_paths (
8790 self ,
@@ -1522,7 +1525,11 @@ def _setup_cohort_queries(
15221525 A cohort set name. Accepted values are:
15231526 "admin1_month", "admin1_quarter", "admin1_year",
15241527 "admin2_month", "admin2_quarter", "admin2_year".
1525- """
1528+ """ ,
1529+ query = """
1530+ An optional pandas query string to filter the resulting
1531+ dataframe, e.g., "country == 'Burkina Faso'".
1532+ """ ,
15261533 ),
15271534 returns = """A dataframe of cohort data, one row per cohort. There are up to 18 columns:
15281535 `cohort_id` is the identifier of the cohort,
@@ -1549,20 +1556,98 @@ def _setup_cohort_queries(
def cohorts(
    self,
    cohort_set: base_params.cohorts,
    query: Optional[str] = None,
) -> pd.DataFrame:
    # Build the accepted names from the two admin levels and three time
    # periods, and reject anything else before touching cache or storage.
    accepted = frozenset(
        f"{level}_{period}"
        for level in ("admin1", "admin2")
        for period in ("month", "quarter", "year")
    )
    if cohort_set not in accepted:
        raise ValueError(
            f"{cohort_set!r} is not a valid cohort set. "
            f"Accepted values are: {sorted(accepted)}."
        )

    analysis = self._cohorts_analysis

    # One cached table per (analysis version, cohort set) pair, so the CSV
    # is read at most once for each combination.
    key = (analysis, cohort_set)
    if key in self._cache_cohorts:
        table = self._cache_cohorts[key]
    else:
        release_prefix = self._major_version_path[:2]
        path = f"{release_prefix}_cohorts/cohorts_{analysis}/cohorts_{cohort_set}.csv"

        with self.open_file(path) as handle:
            table = pd.read_csv(handle, sep=",", na_values="")

        # Normalise column names to lower case before caching.
        table.columns = [name.lower() for name in table.columns]  # type: ignore

        self._cache_cohorts[key] = table

    # Optional row filter; the index is renumbered after filtering.
    if query is not None:
        table = table.query(query).reset_index(drop=True)

    # Hand back a copy so callers cannot modify the cached dataframe.
    return table.copy()
1598+
@_check_types
@doc(
    summary="""
        Read GeoJSON geometry data for a specific cohort set,
        providing boundary geometries for each cohort.
    """,
    parameters=dict(
        cohort_set="""
            A cohort set name. Accepted values are:
            "admin1_month", "admin1_quarter", "admin1_year",
            "admin2_month", "admin2_quarter", "admin2_year".
        """,
    ),
    returns="""
        A dict containing the parsed GeoJSON FeatureCollection,
        with boundary geometries for each cohort in the set. The
        returned dict is an independent copy and can be modified
        without affecting subsequent calls.
    """,
)
def cohort_geometries(
    self,
    cohort_set: base_params.cohorts,
) -> dict:
    # Function-scope import: only the return path below needs it.
    import copy

    valid_cohort_sets = {
        "admin1_month",
        "admin1_quarter",
        "admin1_year",
        "admin2_month",
        "admin2_quarter",
        "admin2_year",
    }
    if cohort_set not in valid_cohort_sets:
        raise ValueError(
            f"{cohort_set!r} is not a valid cohort set. "
            f"Accepted values are: {sorted(valid_cohort_sets)}."
        )

    cohorts_analysis = self._cohorts_analysis

    # Cache to avoid repeated reads; one entry per
    # (analysis version, cohort set) pair.
    cache_key = (cohorts_analysis, cohort_set)
    try:
        geojson_data = self._cache_cohort_geometries[cache_key]
    except KeyError:
        major_version_path = self._major_version_path
        path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.geojson"

        with self.open_file(path) as f:
            geojson_data = json.load(f)

        self._cache_cohort_geometries[cache_key] = geojson_data

    # The parsed GeoJSON is a nested, mutable structure. Return a deep copy
    # so that callers editing the result cannot corrupt the cached object.
    return copy.deepcopy(geojson_data)
15661651
15671652 @_check_types
15681653 @doc (
0 commit comments