|
| 1 | +import difflib |
1 | 2 | import io |
2 | 3 | import json |
3 | 4 | import re |
@@ -797,42 +798,49 @@ def sample_metadata( |
797 | 798 | df_samples = df_samples.reset_index(drop=True) |
798 | 799 |
|
799 | 800 | # Warn if query returned zero results on a non-empty dataset. |
800 | | - # This helps users catch case-sensitivity issues in string queries, |
801 | | - # e.g. "country == 'uganda'" instead of "country == 'Uganda'". |
| 801 | + # Provide fuzzy-match suggestions so users can spot typos, |
| 802 | + # case mismatches, or partial-value issues. |
802 | 803 | if len(df_samples) == 0 and len(df_before_query) > 0: |
803 | | - # Extract column names from comparison expressions in the query. |
804 | | - # Match patterns like: column == 'value' or column == "value" |
805 | | - referenced_cols = re.findall( |
806 | | - r"\b(\w+)\s*[=!<>]+\s*['\"]", prepared_sample_query |
807 | | - ) |
808 | | - |
809 | 804 | hint_lines = [ |
810 | | - f"sample_metadata() returned 0 samples for the given query: {prepared_sample_query!r}.", |
| 805 | + f"sample_metadata() returned 0 samples for query: {prepared_sample_query!r}.", |
811 | 806 | ] |
812 | 807 |
|
813 | | - # Only add the case-sensitivity hint when the query |
814 | | - # contains quoted string literals (not numeric-only). |
815 | | - if re.search(r"['\"].+?['\"]", prepared_sample_query): |
816 | | - hint_lines.append( |
817 | | - "Note: string comparisons in sample_query are case-sensitive." |
818 | | - ) |
| 808 | + # Extract column == 'value' pairs from the query. |
| 809 | + col_val_pairs = re.findall( |
| 810 | + r"\b(\w+)\s*==\s*['\"]([^'\"]+)['\"]", |
| 811 | + prepared_sample_query, |
| 812 | + ) |
819 | 813 |
|
820 | | - # For each referenced string column, list valid values. |
821 | | - for col in dict.fromkeys(referenced_cols): # deduplicate |
822 | | - if ( |
823 | | - col in df_before_query.columns |
824 | | - and df_before_query[col].dtype == object |
825 | | - ): |
826 | | - valid_vals = sorted( |
827 | | - df_before_query[col].dropna().unique().tolist() |
| 814 | + for col_name, queried_val in col_val_pairs: |
| 815 | + # If the column name is not recognised, suggest |
| 816 | + # close column names. |
| 817 | + if col_name not in df_before_query.columns: |
| 818 | + close_cols = difflib.get_close_matches( |
| 819 | + col_name, |
| 820 | + df_before_query.columns.tolist(), |
| 821 | + n=3, |
| 822 | + cutoff=0.6, |
828 | 823 | ) |
829 | | - if len(valid_vals) > 20: |
| 824 | + if close_cols: |
830 | 825 | hint_lines.append( |
831 | | - f"Valid values for column {col!r} (showing 20 of {len(valid_vals)}): {valid_vals[:20]}" |
| 826 | + f"Column {col_name!r} not found. " |
| 827 | + f"Did you mean: {close_cols}?" |
832 | 828 | ) |
833 | | - else: |
| 829 | + continue |
| 830 | + |
| 831 | + # For string columns, suggest close values. |
| 832 | + if df_before_query[col_name].dtype == object: |
| 833 | + valid_vals = ( |
| 834 | + df_before_query[col_name].dropna().unique().tolist() |
| 835 | + ) |
| 836 | + close_vals = difflib.get_close_matches( |
| 837 | + queried_val, valid_vals, n=5, cutoff=0.6 |
| 838 | + ) |
| 839 | + if close_vals: |
834 | 840 | hint_lines.append( |
835 | | - f"Valid values for column {col!r}: {valid_vals}" |
| 841 | + f"Value {queried_val!r} not found in " |
| 842 | + f"column {col_name!r}. " |
| 843 | + f"Did you mean: {close_vals}?" |
836 | 844 | ) |
837 | 845 |
|
838 | 846 | warnings.warn("\n".join(hint_lines), UserWarning, stacklevel=2) |
|
0 commit comments