Skip to content

Commit 5ac7b66

Browse files
committed
feat: use difflib fuzzy matching for zero-result query suggestions
1 parent a4f12e1 commit 5ac7b66

3 files changed

Lines changed: 45 additions & 35 deletions

File tree

malariagen_data/anoph/base_params.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,9 @@
7070
"""
7171
A pandas query string to be evaluated against the sample metadata, to
7272
select samples to be included in the returned data. E.g.,
73-
"country == 'Uganda'". Note: string comparisons are case-sensitive —
74-
column values must match the exact casing stored in the metadata
75-
(e.g., "Uganda" not "uganda"). A warning will be emitted if the query
76-
returns zero results.
73+
"country == 'Uganda'". If the query returns zero results, a warning
74+
will be emitted with fuzzy-match suggestions for possible typos or
75+
case mismatches.
7776
""",
7877
]
7978

malariagen_data/anoph/sample_metadata.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import difflib
12
import io
23
import json
34
import re
@@ -797,42 +798,49 @@ def sample_metadata(
797798
df_samples = df_samples.reset_index(drop=True)
798799

799800
# Warn if query returned zero results on a non-empty dataset.
800-
# This helps users catch case-sensitivity issues in string queries,
801-
# e.g. "country == 'uganda'" instead of "country == 'Uganda'".
801+
# Provide fuzzy-match suggestions so users can spot typos,
802+
# case mismatches, or partial-value issues.
802803
if len(df_samples) == 0 and len(df_before_query) > 0:
803-
# Extract column names from comparison expressions in the query.
804-
# Match patterns like: column == 'value' or column == "value"
805-
referenced_cols = re.findall(
806-
r"\b(\w+)\s*[=!<>]+\s*['\"]", prepared_sample_query
807-
)
808-
809804
hint_lines = [
810-
f"sample_metadata() returned 0 samples for the given query: {prepared_sample_query!r}.",
805+
f"sample_metadata() returned 0 samples for query: {prepared_sample_query!r}.",
811806
]
812807

813-
# Only add the case-sensitivity hint when the query
814-
# contains quoted string literals (not numeric-only).
815-
if re.search(r"['\"].+?['\"]", prepared_sample_query):
816-
hint_lines.append(
817-
"Note: string comparisons in sample_query are case-sensitive."
818-
)
808+
# Extract column == 'value' pairs from the query.
809+
col_val_pairs = re.findall(
810+
r"\b(\w+)\s*==\s*['\"]([^'\"]+)['\"]",
811+
prepared_sample_query,
812+
)
819813

820-
# For each referenced string column, list valid values.
821-
for col in dict.fromkeys(referenced_cols): # deduplicate
822-
if (
823-
col in df_before_query.columns
824-
and df_before_query[col].dtype == object
825-
):
826-
valid_vals = sorted(
827-
df_before_query[col].dropna().unique().tolist()
814+
for col_name, queried_val in col_val_pairs:
815+
# If the column name is not recognised, suggest
816+
# close column names.
817+
if col_name not in df_before_query.columns:
818+
close_cols = difflib.get_close_matches(
819+
col_name,
820+
df_before_query.columns.tolist(),
821+
n=3,
822+
cutoff=0.6,
828823
)
829-
if len(valid_vals) > 20:
824+
if close_cols:
830825
hint_lines.append(
831-
f"Valid values for column {col!r} (showing 20 of {len(valid_vals)}): {valid_vals[:20]}"
826+
f"Column {col_name!r} not found. "
827+
f"Did you mean: {close_cols}?"
832828
)
833-
else:
829+
continue
830+
831+
# For string columns, suggest close values.
832+
if df_before_query[col_name].dtype == object:
833+
valid_vals = (
834+
df_before_query[col_name].dropna().unique().tolist()
835+
)
836+
close_vals = difflib.get_close_matches(
837+
queried_val, valid_vals, n=5, cutoff=0.6
838+
)
839+
if close_vals:
834840
hint_lines.append(
835-
f"Valid values for column {col!r}: {valid_vals}"
841+
f"Value {queried_val!r} not found in "
842+
f"column {col_name!r}. "
843+
f"Did you mean: {close_vals}?"
836844
)
837845

838846
warnings.warn("\n".join(hint_lines), UserWarning, stacklevel=2)

tests/anoph/test_sample_metadata.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,8 +1509,11 @@ def test_cohort_data(fixture, api):
15091509

15101510

15111511
@parametrize_with_cases("fixture,api", cases=".")
1512-
def test_sample_metadata_warns_on_case_mismatch(fixture, api: AnophelesSampleMetadata):
1513-
"""Test that a UserWarning is raised when a case-mismatched query returns 0 results.
1512+
def test_sample_metadata_warns_on_zero_results_with_suggestions(
1513+
fixture, api: AnophelesSampleMetadata
1514+
):
1515+
"""Test that a UserWarning with fuzzy suggestions is raised when a query
1516+
returns 0 results due to a typo or case mismatch.
15141517
15151518
Regression test for https://github.com/malariagen/malariagen-data-python/issues/1083
15161519
"""
@@ -1527,8 +1530,8 @@ def test_sample_metadata_warns_on_case_mismatch(fixture, api: AnophelesSampleMet
15271530
if wrong_case_country == real_country:
15281531
wrong_case_country = real_country.upper()
15291532

1530-
# The wrong-cased query should emit a UserWarning mentioning "case-sensitive".
1531-
with pytest.warns(UserWarning, match="case-sensitive"):
1533+
# The wrong-cased query should emit a UserWarning with fuzzy suggestions.
1534+
with pytest.warns(UserWarning, match="Did you mean"):
15321535
df = api.sample_metadata(sample_query=f"country == '{wrong_case_country}'")
15331536
assert len(df) == 0
15341537

0 commit comments

Comments
 (0)