Skip to content

Commit b89e20b

Browse files
committed
Add filter_unassigned parameter to _prep_samples_for_cohort_grouping
Fixes #806. When building cohorts, the function previously always filtered out samples with 'intermediate' or 'unassigned' taxon values. This was surprising when users specified a custom taxon_by column. The new filter_unassigned parameter (default None) auto-detects: - When taxon_by='taxon' (default): filters as before (backward compat) - When taxon_by is custom: preserves all values - Users can explicitly override with True/False Propagated through snp_allele_frequencies_advanced(), aa_allele_frequencies_advanced(), gene_cnv_frequencies_advanced(), and haplotypes_frequencies_advanced().
1 parent 2d3d2f9 commit b89e20b

6 files changed

Lines changed: 141 additions & 11 deletions

File tree

malariagen_data/anoph/cnv_frq.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ def gene_cnv_frequencies_advanced(
446446
chunks: base_params.chunks = base_params.native_chunks,
447447
inline_array: base_params.inline_array = base_params.inline_array_default,
448448
taxon_by: frq_params.taxon_by = frq_params.taxon_by_default,
449+
filter_unassigned: Optional[frq_params.filter_unassigned] = None,
449450
) -> xr.Dataset:
450451
regions: List[Region] = _parse_multi_region(self, region)
451452
del region
@@ -468,6 +469,7 @@ def gene_cnv_frequencies_advanced(
468469
chunks=chunks,
469470
inline_array=inline_array,
470471
taxon_by=taxon_by,
472+
filter_unassigned=filter_unassigned,
471473
)
472474
for r in regions
473475
],
@@ -497,6 +499,7 @@ def _gene_cnv_frequencies_advanced(
497499
chunks,
498500
inline_array,
499501
taxon_by,
502+
filter_unassigned,
500503
):
501504
debug = self._log.debug
502505

@@ -527,6 +530,7 @@ def _gene_cnv_frequencies_advanced(
527530
area_by=area_by,
528531
period_by=period_by,
529532
taxon_by=taxon_by,
533+
filter_unassigned=filter_unassigned,
530534
)
531535

532536
debug("group samples to make cohorts")

malariagen_data/anoph/frq_base.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,29 @@
1616
from .base import AnophelesBase
1717

1818

19-
def _prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by, taxon_by):
19+
def _prep_samples_for_cohort_grouping(
20+
*, df_samples, area_by, period_by, taxon_by, filter_unassigned=None
21+
):
2022
# Take a copy, as we will modify the dataframe.
2123
df_samples = df_samples.copy()
2224

23-
# Fix "intermediate" or "unassigned" taxon values - we only want to build
24-
# cohorts with clean taxon calls, so we set other values to None.
25-
loc_intermediate_taxon = (
26-
df_samples[taxon_by].str.startswith("intermediate").fillna(False)
27-
)
28-
df_samples.loc[loc_intermediate_taxon, taxon_by] = None
29-
loc_unassigned_taxon = (
30-
df_samples[taxon_by].str.startswith("unassigned").fillna(False)
31-
)
32-
df_samples.loc[loc_unassigned_taxon, taxon_by] = None
25+
# Determine whether to filter "intermediate"/"unassigned" taxon values.
26+
# When filter_unassigned is None (default), auto-apply filtering only
27+
# when using the default "taxon" column. Users can explicitly override
28+
# with True/False.
29+
# See: https://github.com/malariagen/malariagen-data-python/issues/806
30+
if filter_unassigned is None:
31+
filter_unassigned = taxon_by == "taxon"
32+
33+
if filter_unassigned:
34+
loc_intermediate_taxon = (
35+
df_samples[taxon_by].str.startswith("intermediate").fillna(False)
36+
)
37+
df_samples.loc[loc_intermediate_taxon, taxon_by] = None
38+
loc_unassigned_taxon = (
39+
df_samples[taxon_by].str.startswith("unassigned").fillna(False)
40+
)
41+
df_samples.loc[loc_unassigned_taxon, taxon_by] = None
3342

3443
# Add period column.
3544

malariagen_data/anoph/frq_params.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,13 @@
8787
]
8888

8989
taxon_by_default: taxon_by = "taxon"
90+
91+
filter_unassigned: TypeAlias = Annotated[
92+
Optional[bool],
93+
"""
94+
Whether to filter out samples with "intermediate" or "unassigned" taxon
95+
values before building cohorts. If None (default), filtering is applied
96+
only when using the default "taxon" column. Set True to always filter,
97+
or False to never filter.
98+
""",
99+
]

malariagen_data/anoph/hap_frq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def haplotypes_frequencies_advanced(
154154
chunks: base_params.chunks = base_params.native_chunks,
155155
inline_array: base_params.inline_array = base_params.inline_array_default,
156156
taxon_by: frq_params.taxon_by = frq_params.taxon_by_default,
157+
filter_unassigned: Optional[frq_params.filter_unassigned] = None,
157158
) -> xr.Dataset:
158159
# Load sample metadata.
159160
df_samples = self.sample_metadata(
@@ -168,6 +169,7 @@ def haplotypes_frequencies_advanced(
168169
area_by=area_by,
169170
period_by=period_by,
170171
taxon_by=taxon_by,
172+
filter_unassigned=filter_unassigned,
171173
)
172174

173175
# Group samples to make cohorts.

malariagen_data/anoph/snp_frq.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ def snp_allele_frequencies_advanced(
453453
chunks: base_params.chunks = base_params.native_chunks,
454454
inline_array: base_params.inline_array = base_params.inline_array_default,
455455
taxon_by: frq_params.taxon_by = frq_params.taxon_by_default,
456+
filter_unassigned: Optional[frq_params.filter_unassigned] = None,
456457
) -> xr.Dataset:
457458
# Load sample metadata.
458459
df_samples = self.sample_metadata(
@@ -467,6 +468,7 @@ def snp_allele_frequencies_advanced(
467468
area_by=area_by,
468469
period_by=period_by,
469470
taxon_by=taxon_by,
471+
filter_unassigned=filter_unassigned,
470472
)
471473

472474
# Group samples to make cohorts.
@@ -684,6 +686,7 @@ def aa_allele_frequencies_advanced(
684686
chunks: base_params.chunks = base_params.native_chunks,
685687
inline_array: base_params.inline_array = base_params.inline_array_default,
686688
taxon_by: frq_params.taxon_by = frq_params.taxon_by_default,
689+
filter_unassigned: Optional[frq_params.filter_unassigned] = None,
687690
) -> xr.Dataset:
688691
# Begin by computing SNP allele frequencies.
689692
ds_snp_frq = self.snp_allele_frequencies_advanced(
@@ -702,6 +705,7 @@ def aa_allele_frequencies_advanced(
702705
chunks=chunks,
703706
inline_array=inline_array,
704707
taxon_by=taxon_by,
708+
filter_unassigned=filter_unassigned,
705709
)
706710

707711
# N.B., we need to worry about the possibility of the

tests/anoph/test_frq_base.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Tests for _prep_samples_for_cohort_grouping filter_unassigned behavior.
2+
3+
See: https://github.com/malariagen/malariagen-data-python/issues/806
4+
"""
5+
6+
import pandas as pd
7+
8+
from malariagen_data.anoph.frq_base import _prep_samples_for_cohort_grouping
9+
10+
11+
def _make_test_df(taxon_col="taxon"):
12+
"""Create a test DataFrame with intermediate and unassigned taxon values."""
13+
return pd.DataFrame(
14+
{
15+
taxon_col: [
16+
"gambiae",
17+
"intermediate_gambcolu_arabiensis",
18+
"unassigned",
19+
"coluzzii",
20+
],
21+
"admin1_iso": ["KE-01", "KE-01", "KE-02", "KE-02"],
22+
"year": [2020, 2020, 2020, 2020],
23+
"month": [1, 1, 1, 1],
24+
}
25+
)
26+
27+
28+
class TestPrepSamplesFilterUnassigned:
29+
"""Tests for the filter_unassigned parameter in _prep_samples_for_cohort_grouping."""
30+
31+
def test_default_taxon_column_filters(self):
32+
"""When taxon_by='taxon' and filter_unassigned=None (default),
33+
intermediate/unassigned values should be set to None (backward compat)."""
34+
df = _make_test_df(taxon_col="taxon")
35+
result = _prep_samples_for_cohort_grouping(
36+
df_samples=df,
37+
area_by="admin1_iso",
38+
period_by="year",
39+
taxon_by="taxon",
40+
)
41+
assert result["taxon"].iloc[0] == "gambiae"
42+
assert result["taxon"].iloc[1] is None
43+
assert result["taxon"].iloc[2] is None
44+
assert result["taxon"].iloc[3] == "coluzzii"
45+
46+
def test_custom_column_preserves(self):
47+
"""When taxon_by is a custom column and filter_unassigned=None (default),
48+
intermediate/unassigned values should be preserved."""
49+
df = _make_test_df(taxon_col="custom_taxon")
50+
result = _prep_samples_for_cohort_grouping(
51+
df_samples=df,
52+
area_by="admin1_iso",
53+
period_by="year",
54+
taxon_by="custom_taxon",
55+
)
56+
assert result["custom_taxon"].iloc[0] == "gambiae"
57+
assert result["custom_taxon"].iloc[1] == "intermediate_gambcolu_arabiensis"
58+
assert result["custom_taxon"].iloc[2] == "unassigned"
59+
assert result["custom_taxon"].iloc[3] == "coluzzii"
60+
61+
def test_explicit_filter_true(self):
62+
"""When filter_unassigned=True, always filter regardless of column name."""
63+
df = _make_test_df(taxon_col="custom_taxon")
64+
result = _prep_samples_for_cohort_grouping(
65+
df_samples=df,
66+
area_by="admin1_iso",
67+
period_by="year",
68+
taxon_by="custom_taxon",
69+
filter_unassigned=True,
70+
)
71+
assert result["custom_taxon"].iloc[0] == "gambiae"
72+
assert result["custom_taxon"].iloc[1] is None
73+
assert result["custom_taxon"].iloc[2] is None
74+
assert result["custom_taxon"].iloc[3] == "coluzzii"
75+
76+
def test_explicit_filter_false(self):
77+
"""When filter_unassigned=False, never filter even for default 'taxon' column."""
78+
df = _make_test_df(taxon_col="taxon")
79+
result = _prep_samples_for_cohort_grouping(
80+
df_samples=df,
81+
area_by="admin1_iso",
82+
period_by="year",
83+
taxon_by="taxon",
84+
filter_unassigned=False,
85+
)
86+
assert result["taxon"].iloc[0] == "gambiae"
87+
assert result["taxon"].iloc[1] == "intermediate_gambcolu_arabiensis"
88+
assert result["taxon"].iloc[2] == "unassigned"
89+
assert result["taxon"].iloc[3] == "coluzzii"
90+
91+
def test_does_not_modify_original(self):
92+
"""Ensure the original DataFrame is not modified."""
93+
df = _make_test_df(taxon_col="taxon")
94+
original_values = df["taxon"].tolist()
95+
_prep_samples_for_cohort_grouping(
96+
df_samples=df,
97+
area_by="admin1_iso",
98+
period_by="year",
99+
taxon_by="taxon",
100+
)
101+
assert df["taxon"].tolist() == original_values

0 commit comments

Comments
 (0)