Skip to content

Commit 5334be8

Browse files
Merge branch 'master' into GH1120-fix-pipx-link-in-contributing
2 parents 428d678 + 84d190d commit 5334be8

19 files changed

Lines changed: 882 additions & 289 deletions

.github/actions/setup-python/action.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ runs:
1919
shell: bash
2020
run: |
2121
poetry env use ${{ inputs.python-version }}
22-
poetry install --extras dev
22+
poetry install --with dev,test,docs

CONTRIBUTING.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,13 @@ Both of these can be installed using your distribution's package manager or [Hom
5252

5353
```bash
5454
poetry env use 3.12
55-
poetry install --extras dev
55+
poetry install --with dev,test,docs
5656
```
5757

58+
This installs the runtime dependencies along with the `dev`, `test`, and `docs`
59+
[dependency groups](https://python-poetry.org/docs/managing-dependencies/#dependency-groups).
60+
If you only need to run tests, `poetry install --with test` is sufficient.
61+
5862
**Recommended**: Use `poetry run` to run commands inside the virtual environment:
5963

6064
```bash

malariagen_data/anoph/base_params.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@
6969
str,
7070
"""
7171
A pandas query string to be evaluated against the sample metadata, to
72-
select samples to be included in the returned data.
72+
select samples to be included in the returned data. E.g.,
73+
"country == 'Uganda'". If the query returns zero results, a warning
74+
will be emitted with fuzzy-match suggestions for possible typos or
75+
case mismatches.
7376
""",
7477
]
7578

malariagen_data/anoph/frq_base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ def _prep_samples_for_cohort_grouping(
2929
# Users can explicitly override with True/False.
3030
filter_unassigned = taxon_by == "taxon"
3131

32+
# Validate taxon_by.
33+
if taxon_by not in df_samples.columns:
34+
raise ValueError(
35+
f"Invalid value for `taxon_by`: {taxon_by!r}. "
36+
f"Must be the name of an existing column in the sample metadata."
37+
)
38+
3239
if filter_unassigned:
3340
# Remove samples with "intermediate" or "unassigned" taxon values,
3441
# as we only want cohorts with clean taxon calls.
@@ -78,6 +85,13 @@ def _prep_samples_for_cohort_grouping(
7885
# Apply the matching period_by function to create a new "period" column.
7986
df_samples["period"] = df_samples.apply(period_by_func, axis="columns")
8087

88+
# Validate area_by.
89+
if area_by not in df_samples.columns:
90+
raise ValueError(
91+
f"Invalid value for `area_by`: {area_by!r}. "
92+
f"Must be the name of an existing column in the sample metadata."
93+
)
94+
8195
# Copy the specified area_by column to a new "area" column.
8296
df_samples["area"] = df_samples[area_by]
8397

malariagen_data/anoph/sample_metadata.py

Lines changed: 160 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import difflib
12
import io
3+
import json
4+
import re
25
from itertools import cycle
36
from typing import (
47
Any,
@@ -81,6 +84,8 @@ def __init__(
8184

8285
# Initialize cache attributes.
8386
self._cache_sample_metadata: Dict = dict()
87+
self._cache_cohorts: Dict = dict()
88+
self._cache_cohort_geometries: Dict = dict()
8489

8590
def _metadata_paths(
8691
self,
@@ -702,6 +707,17 @@ def clear_extra_metadata(self):
702707
@doc(
703708
summary="Access sample metadata for one or more sample sets.",
704709
returns="A dataframe of sample metadata, one row per sample.",
710+
notes="""
711+
Some samples in the dataset are lab crosses — mosquitoes bred in
712+
the laboratory that have no real collection date. These samples
713+
use ``year=-1`` and ``month=-1`` as sentinel values. They may
714+
cause unexpected results in date-based analyses (e.g.,
715+
``pd.to_datetime`` will fail on negative year values).
716+
717+
To exclude lab cross samples, use::
718+
719+
df = api.sample_metadata(sample_query="year >= 0")
720+
""",
705721
)
706722
def sample_metadata(
707723
self,
@@ -781,12 +797,65 @@ def sample_metadata(
781797
if prepared_sample_query is not None:
782798
# Assume a pandas query string.
783799
sample_query_options = sample_query_options or {}
800+
801+
# Save a reference to the pre-query DataFrame so we can detect
802+
# zero-result queries and provide a helpful warning.
803+
df_before_query = df_samples
804+
784805
# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
785806
df_samples = df_samples.query(
786807
prepared_sample_query, **sample_query_options, engine="python"
787808
)
788809
df_samples = df_samples.reset_index(drop=True)
789810

811+
# Warn if query returned zero results on a non-empty dataset.
812+
# Provide fuzzy-match suggestions so users can spot typos,
813+
# case mismatches, or partial-value issues.
814+
if len(df_samples) == 0 and len(df_before_query) > 0:
815+
hint_lines = [
816+
f"sample_metadata() returned 0 samples for query: {prepared_sample_query!r}.",
817+
]
818+
819+
# Extract column == 'value' pairs from the query.
820+
col_val_pairs = re.findall(
821+
r"\b(\w+)\s*==\s*['\"]([^'\"]+)['\"]",
822+
prepared_sample_query,
823+
)
824+
825+
for col_name, queried_val in col_val_pairs:
826+
# If the column name is not recognised, suggest
827+
# close column names.
828+
if col_name not in df_before_query.columns:
829+
close_cols = difflib.get_close_matches(
830+
col_name,
831+
df_before_query.columns.tolist(),
832+
n=3,
833+
cutoff=0.6,
834+
)
835+
if close_cols:
836+
hint_lines.append(
837+
f"Column {col_name!r} not found. "
838+
f"Did you mean: {close_cols}?"
839+
)
840+
continue
841+
842+
# For string columns, suggest close values.
843+
if df_before_query[col_name].dtype == object:
844+
valid_vals = (
845+
df_before_query[col_name].dropna().unique().tolist()
846+
)
847+
close_vals = difflib.get_close_matches(
848+
queried_val, valid_vals, n=5, cutoff=0.6
849+
)
850+
if close_vals:
851+
hint_lines.append(
852+
f"Value {queried_val!r} not found in "
853+
f"column {col_name!r}. "
854+
f"Did you mean: {close_vals}?"
855+
)
856+
857+
warnings.warn("\n".join(hint_lines), UserWarning, stacklevel=2)
858+
790859
# Apply the sample_indices, if there are any.
791860
# Note: this might need to apply to the result of an internal sample_query, e.g. `is_surveillance == True`.
792861
if sample_indices is not None:
@@ -1485,7 +1554,11 @@ def _setup_cohort_queries(
14851554
A cohort set name. Accepted values are:
14861555
"admin1_month", "admin1_quarter", "admin1_year",
14871556
"admin2_month", "admin2_quarter", "admin2_year".
1488-
"""
1557+
""",
1558+
query="""
1559+
An optional pandas query string to filter the resulting
1560+
dataframe, e.g., "country == 'Burkina Faso'".
1561+
""",
14891562
),
14901563
returns="""A dataframe of cohort data, one row per cohort. There are up to 18 columns:
14911564
`cohort_id` is the identifier of the cohort,
@@ -1512,20 +1585,98 @@ def _setup_cohort_queries(
15121585
def cohorts(
15131586
self,
15141587
cohort_set: base_params.cohorts,
1588+
query: Optional[str] = None,
15151589
) -> pd.DataFrame:
1516-
major_version_path = self._major_version_path
1590+
valid_cohort_sets = {
1591+
"admin1_month",
1592+
"admin1_quarter",
1593+
"admin1_year",
1594+
"admin2_month",
1595+
"admin2_quarter",
1596+
"admin2_year",
1597+
}
1598+
if cohort_set not in valid_cohort_sets:
1599+
raise ValueError(
1600+
f"{cohort_set!r} is not a valid cohort set. "
1601+
f"Accepted values are: {sorted(valid_cohort_sets)}."
1602+
)
1603+
15171604
cohorts_analysis = self._cohorts_analysis
15181605

1519-
path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.csv"
1606+
# Cache to avoid repeated reads.
1607+
cache_key = (cohorts_analysis, cohort_set)
1608+
try:
1609+
df_cohorts = self._cache_cohorts[cache_key]
1610+
except KeyError:
1611+
major_version_path = self._major_version_path
1612+
path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.csv"
1613+
1614+
with self.open_file(path) as f:
1615+
df_cohorts = pd.read_csv(f, sep=",", na_values="")
1616+
1617+
# Ensure all column names are lower case.
1618+
df_cohorts.columns = [c.lower() for c in df_cohorts.columns] # type: ignore
1619+
1620+
self._cache_cohorts[cache_key] = df_cohorts
1621+
1622+
if query is not None:
1623+
df_cohorts = df_cohorts.query(query)
1624+
df_cohorts = df_cohorts.reset_index(drop=True)
1625+
1626+
return df_cohorts.copy()
1627+
1628+
@_check_types
1629+
@doc(
1630+
summary="""
1631+
Read GeoJSON geometry data for a specific cohort set,
1632+
providing boundary geometries for each cohort.
1633+
""",
1634+
parameters=dict(
1635+
cohort_set="""
1636+
A cohort set name. Accepted values are:
1637+
"admin1_month", "admin1_quarter", "admin1_year",
1638+
"admin2_month", "admin2_quarter", "admin2_year".
1639+
""",
1640+
),
1641+
returns="""
1642+
A dict containing the parsed GeoJSON FeatureCollection,
1643+
with boundary geometries for each cohort in the set.
1644+
""",
1645+
)
1646+
def cohort_geometries(
1647+
self,
1648+
cohort_set: base_params.cohorts,
1649+
) -> dict:
1650+
valid_cohort_sets = {
1651+
"admin1_month",
1652+
"admin1_quarter",
1653+
"admin1_year",
1654+
"admin2_month",
1655+
"admin2_quarter",
1656+
"admin2_year",
1657+
}
1658+
if cohort_set not in valid_cohort_sets:
1659+
raise ValueError(
1660+
f"{cohort_set!r} is not a valid cohort set. "
1661+
f"Accepted values are: {sorted(valid_cohort_sets)}."
1662+
)
1663+
1664+
cohorts_analysis = self._cohorts_analysis
1665+
1666+
# Cache to avoid repeated reads.
1667+
cache_key = (cohorts_analysis, cohort_set)
1668+
try:
1669+
geojson_data = self._cache_cohort_geometries[cache_key]
1670+
except KeyError:
1671+
major_version_path = self._major_version_path
1672+
path = f"{major_version_path[:2]}_cohorts/cohorts_{cohorts_analysis}/cohorts_{cohort_set}.geojson"
15201673

1521-
# Read the manifest into a pandas dataframe.
1522-
with self.open_file(path) as f:
1523-
df_cohorts = pd.read_csv(f, sep=",", na_values="")
1674+
with self.open_file(path) as f:
1675+
geojson_data = json.load(f)
15241676

1525-
# Ensure all column names are lower case.
1526-
df_cohorts.columns = [c.lower() for c in df_cohorts.columns] # type: ignore
1677+
self._cache_cohort_geometries[cache_key] = geojson_data
15271678

1528-
return df_cohorts
1679+
return geojson_data
15291680

15301681
@_check_types
15311682
@doc(

malariagen_data/util.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -855,9 +855,7 @@ def _value_error(
855855
value,
856856
expectation,
857857
):
858-
message = (
859-
f"Bad value for parameter {name}; expected {expectation}, " f"found {value!r}"
860-
)
858+
message = f"Bad value for parameter {name}; expected {expectation}, found {value!r}"
861859
raise ValueError(message)
862860

863861

@@ -935,6 +933,7 @@ def info(self, msg):
935933
self.flush()
936934

937935
def set_level(self, level):
936+
self._logger.setLevel(level)
938937
if self._handler is not None:
939938
self._handler.setLevel(level)
940939

0 commit comments

Comments (0)