1+ import difflib
12import io
23import json
4+ import re
35from itertools import cycle
46from typing import (
57 Any ,
@@ -82,6 +84,7 @@ def __init__(
8284
8385 # Initialize cache attributes.
8486 self ._cache_sample_metadata : Dict = dict ()
87+ self ._cache_cohorts : Dict = dict ()
8588 self ._cache_cohort_geometries : Dict = dict ()
8689
8790 def _metadata_paths (
@@ -704,6 +707,17 @@ def clear_extra_metadata(self):
704707 @doc (
705708 summary = "Access sample metadata for one or more sample sets." ,
706709 returns = "A dataframe of sample metadata, one row per sample." ,
710+ notes = """
711+ Some samples in the dataset are lab crosses — mosquitoes bred in
712+ the laboratory that have no real collection date. These samples
713+ use ``year=-1`` and ``month=-1`` as sentinel values. They may
714+ cause unexpected results in date-based analyses (e.g.,
715+ ``pd.to_datetime`` will fail on negative year values).
716+
717+ To exclude lab cross samples, use::
718+
719+ df = api.sample_metadata(sample_query="year >= 0")
720+ """ ,
707721 )
708722 def sample_metadata (
709723 self ,
@@ -783,12 +797,65 @@ def sample_metadata(
783797 if prepared_sample_query is not None :
784798 # Assume a pandas query string.
785799 sample_query_options = sample_query_options or {}
800+
801+ # Save a reference to the pre-query DataFrame so we can detect
802+ # zero-result queries and provide a helpful warning.
803+ df_before_query = df_samples
804+
786805 # Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
787806 df_samples = df_samples .query (
788807 prepared_sample_query , ** sample_query_options , engine = "python"
789808 )
790809 df_samples = df_samples .reset_index (drop = True )
791810
811+ # Warn if query returned zero results on a non-empty dataset.
812+ # Provide fuzzy-match suggestions so users can spot typos,
813+ # case mismatches, or partial-value issues.
814+ if len (df_samples ) == 0 and len (df_before_query ) > 0 :
815+ hint_lines = [
816+ f"sample_metadata() returned 0 samples for query: { prepared_sample_query !r} ." ,
817+ ]
818+
819+ # Extract column == 'value' pairs from the query.
820+ col_val_pairs = re .findall (
821+ r"\b(\w+)\s*==\s*['\"]([^'\"]+)['\"]" ,
822+ prepared_sample_query ,
823+ )
824+
825+ for col_name , queried_val in col_val_pairs :
826+ # If the column name is not recognised, suggest
827+ # close column names.
828+ if col_name not in df_before_query .columns :
829+ close_cols = difflib .get_close_matches (
830+ col_name ,
831+ df_before_query .columns .tolist (),
832+ n = 3 ,
833+ cutoff = 0.6 ,
834+ )
835+ if close_cols :
836+ hint_lines .append (
837+ f"Column { col_name !r} not found. "
838+ f"Did you mean: { close_cols } ?"
839+ )
840+ continue
841+
842+ # For string columns, suggest close values.
843+ if df_before_query [col_name ].dtype == object :
844+ valid_vals = (
845+ df_before_query [col_name ].dropna ().unique ().tolist ()
846+ )
847+ close_vals = difflib .get_close_matches (
848+ queried_val , valid_vals , n = 5 , cutoff = 0.6
849+ )
850+ if close_vals :
851+ hint_lines .append (
852+ f"Value { queried_val !r} not found in "
853+ f"column { col_name !r} . "
854+ f"Did you mean: { close_vals } ?"
855+ )
856+
857+ warnings .warn ("\n " .join (hint_lines ), UserWarning , stacklevel = 2 )
858+
792859 # Apply the sample_indices, if there are any.
793860 # Note: this might need to apply to the result of an internal sample_query, e.g. `is_surveillance == True`.
794861 if sample_indices is not None :
@@ -1487,7 +1554,11 @@ def _setup_cohort_queries(
14871554 A cohort set name. Accepted values are:
14881555 "admin1_month", "admin1_quarter", "admin1_year",
14891556 "admin2_month", "admin2_quarter", "admin2_year".
1490- """
1557+ """ ,
1558+ query = """
1559+ An optional pandas query string to filter the resulting
1560+ dataframe, e.g., "country == 'Burkina Faso'".
1561+ """ ,
14911562 ),
14921563 returns = """A dataframe of cohort data, one row per cohort. There are up to 18 columns:
14931564 `cohort_id` is the identifier of the cohort,
@@ -1514,20 +1585,45 @@ def _setup_cohort_queries(
15141585 def cohorts (
15151586 self ,
15161587 cohort_set : base_params .cohorts ,
1588+ query : Optional [str ] = None ,
15171589 ) -> pd .DataFrame :
1518- major_version_path = self ._major_version_path
1590+ valid_cohort_sets = {
1591+ "admin1_month" ,
1592+ "admin1_quarter" ,
1593+ "admin1_year" ,
1594+ "admin2_month" ,
1595+ "admin2_quarter" ,
1596+ "admin2_year" ,
1597+ }
1598+ if cohort_set not in valid_cohort_sets :
1599+ raise ValueError (
1600+ f"{ cohort_set !r} is not a valid cohort set. "
1601+ f"Accepted values are: { sorted (valid_cohort_sets )} ."
1602+ )
1603+
15191604 cohorts_analysis = self ._cohorts_analysis
15201605
1521- path = f"{ major_version_path [:2 ]} _cohorts/cohorts_{ cohorts_analysis } /cohorts_{ cohort_set } .csv"
1606+ # Cache to avoid repeated reads.
1607+ cache_key = (cohorts_analysis , cohort_set )
1608+ try :
1609+ df_cohorts = self ._cache_cohorts [cache_key ]
1610+ except KeyError :
1611+ major_version_path = self ._major_version_path
1612+ path = f"{ major_version_path [:2 ]} _cohorts/cohorts_{ cohorts_analysis } /cohorts_{ cohort_set } .csv"
1613+
1614+ with self .open_file (path ) as f :
1615+ df_cohorts = pd .read_csv (f , sep = "," , na_values = "" )
1616+
1617+ # Ensure all column names are lower case.
1618+ df_cohorts .columns = [c .lower () for c in df_cohorts .columns ] # type: ignore
15221619
1523- # Read the manifest into a pandas dataframe.
1524- with self .open_file (path ) as f :
1525- df_cohorts = pd .read_csv (f , sep = "," , na_values = "" )
1620+ self ._cache_cohorts [cache_key ] = df_cohorts
15261621
1527- # Ensure all column names are lower case.
1528- df_cohorts .columns = [c .lower () for c in df_cohorts .columns ] # type: ignore
1622+ if query is not None :
1623+ df_cohorts = df_cohorts .query (query )
1624+ df_cohorts = df_cohorts .reset_index (drop = True )
15291625
1530- return df_cohorts
1626+ return df_cohorts . copy ()
15311627
15321628 @_check_types
15331629 @doc (
0 commit comments