1+ import difflib
12import io
3+ import json
4+ import re
25from itertools import cycle
36from typing import (
47 Any ,
@@ -81,6 +84,8 @@ def __init__(
8184
8285 # Initialize cache attributes.
8386 self ._cache_sample_metadata : Dict = dict ()
87+ self ._cache_cohorts : Dict = dict ()
88+ self ._cache_cohort_geometries : Dict = dict ()
8489
8590 def _metadata_paths (
8691 self ,
@@ -702,6 +707,17 @@ def clear_extra_metadata(self):
702707 @doc (
703708 summary = "Access sample metadata for one or more sample sets." ,
704709 returns = "A dataframe of sample metadata, one row per sample." ,
710+ notes = """
711+ Some samples in the dataset are lab crosses — mosquitoes bred in
712+ the laboratory that have no real collection date. These samples
713+ use ``year=-1`` and ``month=-1`` as sentinel values. They may
714+ cause unexpected results in date-based analyses (e.g.,
715+ ``pd.to_datetime`` will fail on negative year values).
716+
717+ To exclude lab cross samples, use::
718+
719+ df = api.sample_metadata(sample_query="year >= 0")
720+ """ ,
705721 )
706722 def sample_metadata (
707723 self ,
@@ -781,12 +797,65 @@ def sample_metadata(
781797 if prepared_sample_query is not None :
782798 # Assume a pandas query string.
783799 sample_query_options = sample_query_options or {}
800+
801+ # Save a reference to the pre-query DataFrame so we can detect
802+ # zero-result queries and provide a helpful warning.
803+ df_before_query = df_samples
804+
784805 # Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
785806 df_samples = df_samples .query (
786807 prepared_sample_query , ** sample_query_options , engine = "python"
787808 )
788809 df_samples = df_samples .reset_index (drop = True )
789810
811+ # Warn if query returned zero results on a non-empty dataset.
812+ # Provide fuzzy-match suggestions so users can spot typos,
813+ # case mismatches, or partial-value issues.
814+ if len (df_samples ) == 0 and len (df_before_query ) > 0 :
815+ hint_lines = [
816+ f"sample_metadata() returned 0 samples for query: { prepared_sample_query !r} ." ,
817+ ]
818+
819+ # Extract column == 'value' pairs from the query.
820+ col_val_pairs = re .findall (
821+ r"\b(\w+)\s*==\s*['\"]([^'\"]+)['\"]" ,
822+ prepared_sample_query ,
823+ )
824+
825+ for col_name , queried_val in col_val_pairs :
826+ # If the column name is not recognised, suggest
827+ # close column names.
828+ if col_name not in df_before_query .columns :
829+ close_cols = difflib .get_close_matches (
830+ col_name ,
831+ df_before_query .columns .tolist (),
832+ n = 3 ,
833+ cutoff = 0.6 ,
834+ )
835+ if close_cols :
836+ hint_lines .append (
837+ f"Column { col_name !r} not found. "
838+ f"Did you mean: { close_cols } ?"
839+ )
840+ continue
841+
842+ # For string columns, suggest close values.
843+ if df_before_query [col_name ].dtype == object :
844+ valid_vals = (
845+ df_before_query [col_name ].dropna ().unique ().tolist ()
846+ )
847+ close_vals = difflib .get_close_matches (
848+ queried_val , valid_vals , n = 5 , cutoff = 0.6
849+ )
850+ if close_vals :
851+ hint_lines .append (
852+ f"Value { queried_val !r} not found in "
853+ f"column { col_name !r} . "
854+ f"Did you mean: { close_vals } ?"
855+ )
856+
857+ warnings .warn ("\n " .join (hint_lines ), UserWarning , stacklevel = 2 )
858+
790859 # Apply the sample_indices, if there are any.
791860 # Note: this might need to apply to the result of an internal sample_query, e.g. `is_surveillance == True`.
792861 if sample_indices is not None :
@@ -1485,7 +1554,11 @@ def _setup_cohort_queries(
14851554 A cohort set name. Accepted values are:
14861555 "admin1_month", "admin1_quarter", "admin1_year",
14871556 "admin2_month", "admin2_quarter", "admin2_year".
1488- """
1557+ """ ,
1558+ query = """
1559+ An optional pandas query string to filter the resulting
1560+ dataframe, e.g., "country == 'Burkina Faso'".
1561+ """ ,
14891562 ),
14901563 returns = """A dataframe of cohort data, one row per cohort. There are up to 18 columns:
14911564 `cohort_id` is the identifier of the cohort,
@@ -1512,20 +1585,98 @@ def _setup_cohort_queries(
def cohorts(
    self,
    cohort_set: base_params.cohorts,
    query: Optional[str] = None,
) -> pd.DataFrame:
    # User-facing documentation for this method is supplied through the
    # decorator applied to it; the comments here are implementation notes.

    # Fail fast with a readable message instead of a confusing
    # file-not-found error for a misspelled cohort set name.
    valid_cohort_sets = {
        "admin1_month",
        "admin1_quarter",
        "admin1_year",
        "admin2_month",
        "admin2_quarter",
        "admin2_year",
    }
    if cohort_set not in valid_cohort_sets:
        raise ValueError(
            f"{cohort_set!r} is not a valid cohort set. "
            f"Accepted values are: {sorted(valid_cohort_sets)}."
        )

    cohorts_analysis = self._cohorts_analysis

    # Cache to avoid repeated reads. The key includes the analysis
    # version so that tables from different analyses never collide.
    cache_key = (cohorts_analysis, cohort_set)
    try:
        df_cohorts = self._cache_cohorts[cache_key]
    except KeyError:
        major_version_path = self._major_version_path
        path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.csv"

        with self.open_file(path) as f:
            df_cohorts = pd.read_csv(f, sep=",", na_values="")

        # Ensure all column names are lower case.
        df_cohorts.columns = [c.lower() for c in df_cohorts.columns]  # type: ignore

        self._cache_cohorts[cache_key] = df_cohorts

    if query is None:
        # Hand out a copy so callers cannot mutate the cached dataframe.
        return df_cohorts.copy()

    # Use the python engine, matching how sample queries are evaluated
    # elsewhere in this class, so extension array dtypes are supported.
    # `query` followed by `reset_index` already yields a new dataframe,
    # so no additional defensive copy is needed on this path.
    return df_cohorts.query(query, engine="python").reset_index(drop=True)
1627+
1628+ @_check_types
1629+ @doc (
1630+ summary = """
1631+ Read GeoJSON geometry data for a specific cohort set,
1632+ providing boundary geometries for each cohort.
1633+ """ ,
1634+ parameters = dict (
1635+ cohort_set = """
1636+ A cohort set name. Accepted values are:
1637+ "admin1_month", "admin1_quarter", "admin1_year",
1638+ "admin2_month", "admin2_quarter", "admin2_year".
1639+ """ ,
1640+ ),
1641+ returns = """
1642+ A dict containing the parsed GeoJSON FeatureCollection,
1643+ with boundary geometries for each cohort in the set.
1644+ """ ,
1645+ )
def cohort_geometries(
    self,
    cohort_set: base_params.cohorts,
) -> dict:
    # User-facing documentation for this method is supplied through the
    # decorator applied to it; the comments here are implementation notes.

    # Function-scope import keeps this block self-contained.
    import copy

    # Fail fast with a readable message instead of a confusing
    # file-not-found error for a misspelled cohort set name.
    valid_cohort_sets = {
        "admin1_month",
        "admin1_quarter",
        "admin1_year",
        "admin2_month",
        "admin2_quarter",
        "admin2_year",
    }
    if cohort_set not in valid_cohort_sets:
        raise ValueError(
            f"{cohort_set!r} is not a valid cohort set. "
            f"Accepted values are: {sorted(valid_cohort_sets)}."
        )

    cohorts_analysis = self._cohorts_analysis

    # Cache to avoid repeated reads. The key includes the analysis
    # version so that geometries from different analyses never collide.
    cache_key = (cohorts_analysis, cohort_set)
    try:
        geojson_data = self._cache_cohort_geometries[cache_key]
    except KeyError:
        major_version_path = self._major_version_path
        path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.geojson"

        with self.open_file(path) as f:
            geojson_data = json.load(f)

        self._cache_cohort_geometries[cache_key] = geojson_data

    # The parsed GeoJSON is a nested, mutable structure. Return a deep
    # copy so that a caller editing the result (e.g., filtering the
    # features list in place) cannot corrupt the cached value that
    # later calls will receive.
    return copy.deepcopy(geojson_data)
15291680
15301681 @_check_types
15311682 @doc (
0 commit comments