Skip to content

Commit cb18a86

Browse files
authored
Merge branch 'master' into feat/haplotype-sharing-arc-chord
2 parents f0d5d12 + 63a8201 commit cb18a86

14 files changed

Lines changed: 475 additions & 283 deletions

File tree

.github/actions/setup-python/action.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,4 +19,4 @@ runs:
1919
shell: bash
2020
run: |
2121
poetry env use ${{ inputs.python-version }}
22-
poetry install --extras dev
22+
poetry install --with dev,test,docs

CONTRIBUTING.md

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,13 @@ Both of these can be installed using your distribution's package manager or [Hom
5252

5353
```bash
5454
poetry env use 3.12
55-
poetry install --extras dev
55+
poetry install --with dev,test,docs
5656
```
5757

58+
This installs the runtime dependencies along with the `dev`, `test`, and `docs`
59+
[dependency groups](https://python-poetry.org/docs/managing-dependencies/#dependency-groups).
60+
If you only need to run tests, `poetry install --with test` is sufficient.
61+
5862
**Recommended**: Use `poetry run` to run commands inside the virtual environment:
5963

6064
```bash

malariagen_data/anoph/base_params.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,10 @@
6969
str,
7070
"""
7171
A pandas query string to be evaluated against the sample metadata, to
72-
select samples to be included in the returned data.
72+
select samples to be included in the returned data. E.g.,
73+
"country == 'Uganda'". If the query returns zero results, a warning
74+
will be emitted with fuzzy-match suggestions for possible typos or
75+
case mismatches.
7376
""",
7477
]
7578

malariagen_data/anoph/frq_base.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,13 @@ def _prep_samples_for_cohort_grouping(
7878
# Apply the matching period_by function to create a new "period" column.
7979
df_samples["period"] = df_samples.apply(period_by_func, axis="columns")
8080

81+
# Validate area_by.
82+
if area_by not in df_samples.columns:
83+
raise ValueError(
84+
f"Invalid value for `area_by`: {area_by!r}. "
85+
f"Must be the name of an existing column in the sample metadata."
86+
)
87+
8188
# Copy the specified area_by column to a new "area" column.
8289
df_samples["area"] = df_samples[area_by]
8390

malariagen_data/anoph/sample_metadata.py

Lines changed: 105 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
1+
import difflib
12
import io
23
import json
4+
import re
35
from itertools import cycle
46
from typing import (
57
Any,
@@ -82,6 +84,7 @@ def __init__(
8284

8385
# Initialize cache attributes.
8486
self._cache_sample_metadata: Dict = dict()
87+
self._cache_cohorts: Dict = dict()
8588
self._cache_cohort_geometries: Dict = dict()
8689

8790
def _metadata_paths(
@@ -704,6 +707,17 @@ def clear_extra_metadata(self):
704707
@doc(
705708
summary="Access sample metadata for one or more sample sets.",
706709
returns="A dataframe of sample metadata, one row per sample.",
710+
notes="""
711+
Some samples in the dataset are lab crosses — mosquitoes bred in
712+
the laboratory that have no real collection date. These samples
713+
use ``year=-1`` and ``month=-1`` as sentinel values. They may
714+
cause unexpected results in date-based analyses (e.g.,
715+
``pd.to_datetime`` will fail on negative year values).
716+
717+
To exclude lab cross samples, use::
718+
719+
df = api.sample_metadata(sample_query="year >= 0")
720+
""",
707721
)
708722
def sample_metadata(
709723
self,
@@ -783,12 +797,65 @@ def sample_metadata(
783797
if prepared_sample_query is not None:
784798
# Assume a pandas query string.
785799
sample_query_options = sample_query_options or {}
800+
801+
# Save a reference to the pre-query DataFrame so we can detect
802+
# zero-result queries and provide a helpful warning.
803+
df_before_query = df_samples
804+
786805
# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
787806
df_samples = df_samples.query(
788807
prepared_sample_query, **sample_query_options, engine="python"
789808
)
790809
df_samples = df_samples.reset_index(drop=True)
791810

811+
# Warn if query returned zero results on a non-empty dataset.
812+
# Provide fuzzy-match suggestions so users can spot typos,
813+
# case mismatches, or partial-value issues.
814+
if len(df_samples) == 0 and len(df_before_query) > 0:
815+
hint_lines = [
816+
f"sample_metadata() returned 0 samples for query: {prepared_sample_query!r}.",
817+
]
818+
819+
# Extract column == 'value' pairs from the query.
820+
col_val_pairs = re.findall(
821+
r"\b(\w+)\s*==\s*['\"]([^'\"]+)['\"]",
822+
prepared_sample_query,
823+
)
824+
825+
for col_name, queried_val in col_val_pairs:
826+
# If the column name is not recognised, suggest
827+
# close column names.
828+
if col_name not in df_before_query.columns:
829+
close_cols = difflib.get_close_matches(
830+
col_name,
831+
df_before_query.columns.tolist(),
832+
n=3,
833+
cutoff=0.6,
834+
)
835+
if close_cols:
836+
hint_lines.append(
837+
f"Column {col_name!r} not found. "
838+
f"Did you mean: {close_cols}?"
839+
)
840+
continue
841+
842+
# For string columns, suggest close values.
843+
if df_before_query[col_name].dtype == object:
844+
valid_vals = (
845+
df_before_query[col_name].dropna().unique().tolist()
846+
)
847+
close_vals = difflib.get_close_matches(
848+
queried_val, valid_vals, n=5, cutoff=0.6
849+
)
850+
if close_vals:
851+
hint_lines.append(
852+
f"Value {queried_val!r} not found in "
853+
f"column {col_name!r}. "
854+
f"Did you mean: {close_vals}?"
855+
)
856+
857+
warnings.warn("\n".join(hint_lines), UserWarning, stacklevel=2)
858+
792859
# Apply the sample_indices, if there are any.
793860
# Note: this might need to apply to the result of an internal sample_query, e.g. `is_surveillance == True`.
794861
if sample_indices is not None:
@@ -1487,7 +1554,11 @@ def _setup_cohort_queries(
14871554
A cohort set name. Accepted values are:
14881555
"admin1_month", "admin1_quarter", "admin1_year",
14891556
"admin2_month", "admin2_quarter", "admin2_year".
1490-
"""
1557+
""",
1558+
query="""
1559+
An optional pandas query string to filter the resulting
1560+
dataframe, e.g., "country == 'Burkina Faso'".
1561+
""",
14911562
),
14921563
returns="""A dataframe of cohort data, one row per cohort. There are up to 18 columns:
14931564
`cohort_id` is the identifier of the cohort,
@@ -1514,20 +1585,45 @@ def _setup_cohort_queries(
15141585
def cohorts(
15151586
self,
15161587
cohort_set: base_params.cohorts,
1588+
query: Optional[str] = None,
15171589
) -> pd.DataFrame:
1518-
major_version_path = self._major_version_path
1590+
valid_cohort_sets = {
1591+
"admin1_month",
1592+
"admin1_quarter",
1593+
"admin1_year",
1594+
"admin2_month",
1595+
"admin2_quarter",
1596+
"admin2_year",
1597+
}
1598+
if cohort_set not in valid_cohort_sets:
1599+
raise ValueError(
1600+
f"{cohort_set!r} is not a valid cohort set. "
1601+
f"Accepted values are: {sorted(valid_cohort_sets)}."
1602+
)
1603+
15191604
cohorts_analysis = self._cohorts_analysis
15201605

1521-
path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.csv"
1606+
# Cache to avoid repeated reads.
1607+
cache_key = (cohorts_analysis, cohort_set)
1608+
try:
1609+
df_cohorts = self._cache_cohorts[cache_key]
1610+
except KeyError:
1611+
major_version_path = self._major_version_path
1612+
path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.csv"
1613+
1614+
with self.open_file(path) as f:
1615+
df_cohorts = pd.read_csv(f, sep=",", na_values="")
1616+
1617+
# Ensure all column names are lower case.
1618+
df_cohorts.columns = [c.lower() for c in df_cohorts.columns] # type: ignore
15221619

1523-
# Read the manifest into a pandas dataframe.
1524-
with self.open_file(path) as f:
1525-
df_cohorts = pd.read_csv(f, sep=",", na_values="")
1620+
self._cache_cohorts[cache_key] = df_cohorts
15261621

1527-
# Ensure all column names are lower case.
1528-
df_cohorts.columns = [c.lower() for c in df_cohorts.columns] # type: ignore
1622+
if query is not None:
1623+
df_cohorts = df_cohorts.query(query)
1624+
df_cohorts = df_cohorts.reset_index(drop=True)
15291625

1530-
return df_cohorts
1626+
return df_cohorts.copy()
15311627

15321628
@_check_types
15331629
@doc(

malariagen_data/util.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -855,9 +855,7 @@ def _value_error(
855855
value,
856856
expectation,
857857
):
858-
message = (
859-
f"Bad value for parameter {name}; expected {expectation}, " f"found {value!r}"
860-
)
858+
message = f"Bad value for parameter {name}; expected {expectation}, found {value!r}"
861859
raise ValueError(message)
862860

863861

@@ -935,6 +933,7 @@ def info(self, msg):
935933
self.flush()
936934

937935
def set_level(self, level):
936+
self._logger.setLevel(level)
938937
if self._handler is not None:
939938
self._handler.setLevel(level)
940939

0 commit comments

Comments
 (0)