Skip to content

Commit f7d09a5

Browse files
Merge branch 'master' into GH1097-migrate-extras-to-groups
2 parents ea1340f + 32d906b commit f7d09a5

3 files changed

Lines changed: 117 additions & 1 deletion

File tree

malariagen_data/anoph/base_params.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@
6969
str,
7070
"""
7171
A pandas query string to be evaluated against the sample metadata, to
72-
select samples to be included in the returned data.
72+
select samples to be included in the returned data. E.g.,
73+
"country == 'Uganda'". If the query returns zero results, a warning
74+
will be emitted with fuzzy-match suggestions for possible typos or
75+
case mismatches.
7376
""",
7477
]
7578

malariagen_data/anoph/sample_metadata.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import difflib
12
import io
23
import json
4+
import re
35
from itertools import cycle
46
from typing import (
57
Any,
@@ -705,6 +707,17 @@ def clear_extra_metadata(self):
705707
@doc(
706708
summary="Access sample metadata for one or more sample sets.",
707709
returns="A dataframe of sample metadata, one row per sample.",
710+
notes="""
711+
Some samples in the dataset are lab crosses — mosquitoes bred in
712+
the laboratory that have no real collection date. These samples
713+
use ``year=-1`` and ``month=-1`` as sentinel values. They may
714+
cause unexpected results in date-based analyses (e.g.,
715+
``pd.to_datetime`` will fail on negative year values).
716+
717+
To exclude lab cross samples, use::
718+
719+
df = api.sample_metadata(sample_query="year >= 0")
720+
""",
708721
)
709722
def sample_metadata(
710723
self,
@@ -784,12 +797,65 @@ def sample_metadata(
784797
if prepared_sample_query is not None:
785798
# Assume a pandas query string.
786799
sample_query_options = sample_query_options or {}
800+
801+
# Save a reference to the pre-query DataFrame so we can detect
802+
# zero-result queries and provide a helpful warning.
803+
df_before_query = df_samples
804+
787805
# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
788806
df_samples = df_samples.query(
789807
prepared_sample_query, **sample_query_options, engine="python"
790808
)
791809
df_samples = df_samples.reset_index(drop=True)
792810

811+
# Warn if query returned zero results on a non-empty dataset.
812+
# Provide fuzzy-match suggestions so users can spot typos,
813+
# case mismatches, or partial-value issues.
814+
if len(df_samples) == 0 and len(df_before_query) > 0:
815+
hint_lines = [
816+
f"sample_metadata() returned 0 samples for query: {prepared_sample_query!r}.",
817+
]
818+
819+
# Extract column == 'value' pairs from the query.
820+
col_val_pairs = re.findall(
821+
r"\b(\w+)\s*==\s*['\"]([^'\"]+)['\"]",
822+
prepared_sample_query,
823+
)
824+
825+
for col_name, queried_val in col_val_pairs:
826+
# If the column name is not recognised, suggest
827+
# close column names.
828+
if col_name not in df_before_query.columns:
829+
close_cols = difflib.get_close_matches(
830+
col_name,
831+
df_before_query.columns.tolist(),
832+
n=3,
833+
cutoff=0.6,
834+
)
835+
if close_cols:
836+
hint_lines.append(
837+
f"Column {col_name!r} not found. "
838+
f"Did you mean: {close_cols}?"
839+
)
840+
continue
841+
842+
# For string columns, suggest close values.
843+
if df_before_query[col_name].dtype == object:
844+
valid_vals = (
845+
df_before_query[col_name].dropna().unique().tolist()
846+
)
847+
close_vals = difflib.get_close_matches(
848+
queried_val, valid_vals, n=5, cutoff=0.6
849+
)
850+
if close_vals:
851+
hint_lines.append(
852+
f"Value {queried_val!r} not found in "
853+
f"column {col_name!r}. "
854+
f"Did you mean: {close_vals}?"
855+
)
856+
857+
warnings.warn("\n".join(hint_lines), UserWarning, stacklevel=2)
858+
793859
# Apply the sample_indices, if there are any.
794860
# Note: this might need to apply to the result of an internal sample_query, e.g. `is_surveillance == True`.
795861
if sample_indices is not None:

tests/anoph/test_sample_metadata.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,6 +1508,53 @@ def test_cohort_data(fixture, api):
15081508
validate_cohort_data(df_cohorts, cohort_data_expected_columns())
15091509

15101510

1511+
@parametrize_with_cases("fixture,api", cases=".")
def test_sample_metadata_warns_on_zero_results_with_suggestions(
    fixture, api: AnophelesSampleMetadata
):
    """Test that a UserWarning with fuzzy suggestions is raised when a query
    returns 0 results due to a typo or case mismatch.

    Regression test for https://github.com/malariagen/malariagen-data-python/issues/1083
    """
    # Get the valid country names from the metadata so we can construct
    # a deliberately wrong-cased query.
    df_all = api.sample_metadata()
    if "country" not in df_all.columns or df_all["country"].dropna().empty:
        pytest.skip("No 'country' column with data in this fixture.")

    countries = [str(value) for value in df_all["country"].dropna().unique()]
    known_countries = set(countries)

    # Pick a real country value and change its case. Two kinds of value are
    # unusable here and are passed over:
    #   * values containing a quote character (e.g. "Cote d'Ivoire"), which
    #     would produce a malformed query string when interpolated between
    #     single quotes below;
    #   * values whose lower- and upper-cased forms are both real values
    #     (e.g. no cased characters), because the query would then match rows
    #     and no warning would be emitted.
    wrong_case_country = None
    for real_country in countries:
        if "'" in real_country or '"' in real_country:
            continue
        # Prefer lowercasing; fall back to uppercasing if lowercasing did not
        # yield a value absent from the data.
        for variant in (real_country.lower(), real_country.upper()):
            if variant not in known_countries:
                wrong_case_country = variant
                break
        if wrong_case_country is not None:
            break

    if wrong_case_country is None:
        pytest.skip("No 'country' value suitable for a wrong-case query.")

    # The wrong-cased query should emit a UserWarning with fuzzy suggestions.
    with pytest.warns(UserWarning, match="Did you mean"):
        df = api.sample_metadata(sample_query=f"country == '{wrong_case_country}'")
    assert len(df) == 0
1537+
1538+
1539+
@parametrize_with_cases("fixture,api", cases=".")
def test_sample_metadata_no_warning_on_valid_query(
    fixture, api: AnophelesSampleMetadata
):
    """Test that no spurious warning is emitted when a valid query returns results."""
    df_all = api.sample_metadata()
    if "country" not in df_all.columns or df_all["country"].dropna().empty:
        pytest.skip("No 'country' column with data in this fixture.")

    real_country = df_all["country"].dropna().iloc[0]

    # Build the string literal with repr() rather than hand-placed single
    # quotes, so a value containing an apostrophe (e.g. "Cote d'Ivoire")
    # still yields a syntactically valid query. For plain values this
    # produces exactly the same text as '<value>'.
    country_literal = repr(str(real_country))

    import warnings

    # Escalate UserWarning to an error inside this context so that any
    # warning raised by the call fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        df = api.sample_metadata(sample_query=f"country == {country_literal}")
    assert len(df) > 0
1556+
1557+
15111558
@parametrize_with_cases("fixture,api", cases=case_ag3_sim)
15121559
def test_cohort_data_admin1_year(fixture, api):
15131560
df_cohorts = api.cohorts("admin1_year")

0 commit comments

Comments
 (0)