Skip to content

Commit d53e326

Browse files
authored
Merge pull request #694 from malariagen/GH391_add_params_to_allele_frequencies_advanced
Add "taxon_by" param to `*_frequencies_advanced()` functions and allow the "period_by" param to specify a column name
2 parents 2dba673 + ad20786 commit d53e326

File tree

11 files changed

+348
-136
lines changed

11 files changed

+348
-136
lines changed

malariagen_data/anoph/cnv_frq.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ def gene_cnv_frequencies_advanced(
445445
ci_method: Optional[frq_params.ci_method] = frq_params.ci_method_default,
446446
chunks: base_params.chunks = base_params.native_chunks,
447447
inline_array: base_params.inline_array = base_params.inline_array_default,
448+
taxon_by: frq_params.taxon_by = frq_params.taxon_by_default,
448449
) -> xr.Dataset:
449450
regions: List[Region] = parse_multi_region(self, region)
450451
del region
@@ -466,6 +467,7 @@ def gene_cnv_frequencies_advanced(
466467
ci_method=ci_method,
467468
chunks=chunks,
468469
inline_array=inline_array,
470+
taxon_by=taxon_by,
469471
)
470472
for r in regions
471473
],
@@ -494,6 +496,7 @@ def _gene_cnv_frequencies_advanced(
494496
ci_method,
495497
chunks,
496498
inline_array,
499+
taxon_by,
497500
):
498501
debug = self._log.debug
499502

@@ -523,15 +526,17 @@ def _gene_cnv_frequencies_advanced(
523526
df_samples=df_samples,
524527
area_by=area_by,
525528
period_by=period_by,
529+
taxon_by=taxon_by,
526530
)
527531

528532
debug("group samples to make cohorts")
529-
group_samples_by_cohort = df_samples.groupby(["taxon", "area", "period"])
533+
group_samples_by_cohort = df_samples.groupby([taxon_by, "area", "period"])
530534

531535
debug("build cohorts dataframe")
532536
df_cohorts = build_cohorts_from_sample_grouping(
533537
group_samples_by_cohort=group_samples_by_cohort,
534538
min_cohort_size=min_cohort_size,
539+
taxon_by=taxon_by,
535540
)
536541

537542
debug("figure out expected copy number")
@@ -556,7 +561,8 @@ def _gene_cnv_frequencies_advanced(
556561
debug("build event count and nobs for each cohort")
557562
for cohort_index, cohort in enumerate(df_cohorts.itertuples()):
558563
# construct grouping key
559-
cohort_key = cohort.taxon, cohort.area, cohort.period
564+
cohort_taxon = getattr(cohort, taxon_by)
565+
cohort_key = cohort_taxon, cohort.area, cohort.period
560566

561567
# obtain sample indices for cohort
562568
sample_indices = group_samples_by_cohort.indices[cohort_key]

malariagen_data/anoph/frq_base.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pandas as pd
3+
import re
34
import xarray as xr
45
import plotly.express as px
56
from textwrap import dedent
@@ -14,42 +15,67 @@
1415
from .base import AnophelesBase
1516

1617

17-
def prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by):
18+
def prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by, taxon_by):
1819
# Take a copy, as we will modify the dataframe.
1920
df_samples = df_samples.copy()
2021

2122
# Fix "intermediate" or "unassigned" taxon values - we only want to build
2223
# cohorts with clean taxon calls, so we set other values to None.
2324
loc_intermediate_taxon = (
24-
df_samples["taxon"].str.startswith("intermediate").fillna(False)
25+
df_samples[taxon_by].str.startswith("intermediate").fillna(False)
2526
)
26-
df_samples.loc[loc_intermediate_taxon, "taxon"] = None
27+
df_samples.loc[loc_intermediate_taxon, taxon_by] = None
2728
loc_unassigned_taxon = (
28-
df_samples["taxon"].str.startswith("unassigned").fillna(False)
29+
df_samples[taxon_by].str.startswith("unassigned").fillna(False)
2930
)
30-
df_samples.loc[loc_unassigned_taxon, "taxon"] = None
31+
df_samples.loc[loc_unassigned_taxon, taxon_by] = None
3132

3233
# Add period column.
33-
if period_by == "year":
34-
make_period = _make_sample_period_year
35-
elif period_by == "quarter":
36-
make_period = _make_sample_period_quarter
37-
elif period_by == "month":
38-
make_period = _make_sample_period_month
39-
else: # pragma: no cover
40-
raise ValueError(
41-
f"Value for period_by parameter must be one of 'year', 'quarter', 'month'; found {period_by!r}."
42-
)
43-
sample_period = df_samples.apply(make_period, axis="columns")
44-
df_samples["period"] = sample_period
4534

46-
# Add area column for consistent output.
35+
# Map supported period_by values to functions that return either the relevant pd.Period or pd.NaT per row.
36+
period_by_funcs = {
37+
"year": _make_sample_period_year,
38+
"quarter": _make_sample_period_quarter,
39+
"month": _make_sample_period_month,
40+
}
41+
42+
# Get the matching function for the specified period_by value, or None.
43+
period_by_func = period_by_funcs.get(period_by)
44+
45+
# If there were no matching functions for the specified period_by value...
46+
if period_by_func is None:
47+
# Raise a ValueError if the specified period_by value is not a column in the DataFrame.
48+
if period_by not in df_samples.columns:
49+
raise ValueError(
50+
f"Invalid value for `period_by`: {period_by!r}. Either specify the name of an existing column "
51+
"or a supported period: 'year', 'quarter', or 'month'."
52+
)
53+
54+
# Raise a TypeError if the specified period_by column does not contain instances of pd.Period.
55+
if not all(
56+
df_samples[period_by].apply(
57+
lambda value: pd.isnull(value) or isinstance(value, pd.Period)
58+
)
59+
):
60+
raise TypeError(
61+
f"Invalid values in {period_by!r} column. Must be either pandas.Period or null."
62+
)
63+
64+
# Copy the specified period_by column to a new "period" column.
65+
df_samples["period"] = df_samples[period_by]
66+
else:
67+
# Apply the matching period_by function to create a new "period" column.
68+
df_samples["period"] = df_samples.apply(period_by_func, axis="columns")
69+
70+
# Copy the specified area_by column to a new "area" column.
4771
df_samples["area"] = df_samples[area_by]
4872

4973
return df_samples
5074

5175

52-
def build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_size):
76+
def build_cohorts_from_sample_grouping(
77+
*, group_samples_by_cohort, min_cohort_size, taxon_by
78+
):
5379
# Build cohorts dataframe.
5480
df_cohorts = group_samples_by_cohort.agg(
5581
size=("sample_id", len),
@@ -70,9 +96,16 @@ def build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_si
7096
df_cohorts["period_end"] = cohort_period_end
7197
# Create a label that is similar to the cohort metadata,
7298
# although this won't be perfect.
73-
df_cohorts["label"] = df_cohorts.apply(
74-
lambda v: f"{v.area}_{v.taxon[:4]}_{v.period}", axis="columns"
75-
)
99+
if taxon_by == frq_params.taxon_by_default:
100+
df_cohorts["label"] = df_cohorts.apply(
101+
lambda v: f"{v.area}_{v[taxon_by][:4]}_{v.period}", axis="columns"
102+
)
103+
else:
104+
# Replace non-alphanumeric characters in the taxon with underscores.
105+
df_cohorts["label"] = df_cohorts.apply(
106+
lambda v: f"{v.area}_{re.sub(r'[^A-Za-z0-9]+', '_', str(v[taxon_by]))}_{v.period}",
107+
axis="columns",
108+
)
76109

77110
# Apply minimum cohort size.
78111
df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True)

malariagen_data/anoph/frq_params.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
]
2626

2727
period_by: TypeAlias = Annotated[
28-
Literal["year", "quarter", "month"],
29-
"Length of time to group samples temporally.",
28+
Union[str, Literal["year", "quarter", "month"]],
29+
"Either the length of time to group samples temporally or the name of the column to use.",
3030
]
3131

3232
variant_query: TypeAlias = Annotated[
@@ -80,3 +80,10 @@
8080
Optional[Union[str, List[str], Tuple[str, ...]]],
8181
"The area or areas to restrict the dataset to.",
8282
]
83+
84+
taxon_by: TypeAlias = Annotated[
85+
str,
86+
"The column to use for taxon stratification.",
87+
]
88+
89+
taxon_by_default: taxon_by = "taxon"

malariagen_data/anoph/hap_frq.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def haplotypes_frequencies_advanced(
153153
ci_method: Optional[frq_params.ci_method] = frq_params.ci_method_default,
154154
chunks: base_params.chunks = base_params.native_chunks,
155155
inline_array: base_params.inline_array = base_params.inline_array_default,
156+
taxon_by: frq_params.taxon_by = frq_params.taxon_by_default,
156157
) -> xr.Dataset:
157158
# Load sample metadata.
158159
df_samples = self.sample_metadata(
@@ -166,15 +167,17 @@ def haplotypes_frequencies_advanced(
166167
df_samples=df_samples,
167168
area_by=area_by,
168169
period_by=period_by,
170+
taxon_by=taxon_by,
169171
)
170172

171173
# Group samples to make cohorts.
172-
group_samples_by_cohort = df_samples.groupby(["taxon", "area", "period"])
174+
group_samples_by_cohort = df_samples.groupby([taxon_by, "area", "period"])
173175

174176
# Build cohorts dataframe.
175177
df_cohorts = build_cohorts_from_sample_grouping(
176178
group_samples_by_cohort=group_samples_by_cohort,
177179
min_cohort_size=min_cohort_size,
180+
taxon_by=taxon_by,
178181
)
179182

180183
# Access haplotypes.
@@ -211,8 +214,9 @@ def haplotypes_frequencies_advanced(
211214
df_cohorts.itertuples(), desc="Compute allele frequencies"
212215
)
213216
for cohort in cohorts_iterator:
214-
cohort_key = cohort.taxon, cohort.area, cohort.period
215-
cohort_key_str = cohort.taxon + "_" + cohort.area + "_" + str(cohort.period)
217+
cohort_taxon = getattr(cohort, taxon_by)
218+
cohort_key = cohort_taxon, cohort.area, cohort.period
219+
cohort_key_str = cohort_taxon + "_" + cohort.area + "_" + str(cohort.period)
216220
# We reset all frequencies, counts to 0 for each cohort, nobs is set to the number of haplotypes
217221
n_samples = cohort.size
218222
hap_freq = {k: 0 for k in f_all.keys()}

malariagen_data/anoph/snp_frq.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ def snp_allele_frequencies_advanced(
452452
ci_method: Optional[frq_params.ci_method] = frq_params.ci_method_default,
453453
chunks: base_params.chunks = base_params.native_chunks,
454454
inline_array: base_params.inline_array = base_params.inline_array_default,
455+
taxon_by: frq_params.taxon_by = frq_params.taxon_by_default,
455456
) -> xr.Dataset:
456457
# Load sample metadata.
457458
df_samples = self.sample_metadata(
@@ -465,15 +466,17 @@ def snp_allele_frequencies_advanced(
465466
df_samples=df_samples,
466467
area_by=area_by,
467468
period_by=period_by,
469+
taxon_by=taxon_by,
468470
)
469471

470472
# Group samples to make cohorts.
471-
group_samples_by_cohort = df_samples.groupby(["taxon", "area", "period"])
473+
group_samples_by_cohort = df_samples.groupby([taxon_by, "area", "period"])
472474

473475
# Build cohorts dataframe.
474476
df_cohorts = build_cohorts_from_sample_grouping(
475477
group_samples_by_cohort=group_samples_by_cohort,
476478
min_cohort_size=min_cohort_size,
479+
taxon_by=taxon_by,
477480
)
478481

479482
# Early check for no cohorts.
@@ -529,7 +532,8 @@ def snp_allele_frequencies_advanced(
529532
desc="Compute SNP allele frequencies",
530533
)
531534
for cohort_index, cohort in cohorts_iterator:
532-
cohort_key = cohort.taxon, cohort.area, cohort.period
535+
cohort_taxon = getattr(cohort, taxon_by)
536+
cohort_key = cohort_taxon, cohort.area, cohort.period
533537
sample_indices = group_samples_by_cohort.indices[cohort_key]
534538

535539
cohort_ac, cohort_an = _cohort_alt_allele_counts_melt(
@@ -601,7 +605,11 @@ def snp_allele_frequencies_advanced(
601605

602606
# Cohort variables.
603607
for coh_col in df_cohorts.columns:
604-
ds_out[f"cohort_{coh_col}"] = "cohorts", df_cohorts[coh_col]
608+
if coh_col == taxon_by:
609+
# Other functions expect cohort_taxon, e.g. plot_frequencies_interactive_map()
610+
ds_out["cohort_taxon"] = "cohorts", df_cohorts[coh_col]
611+
else:
612+
ds_out[f"cohort_{coh_col}"] = "cohorts", df_cohorts[coh_col]
605613

606614
# Variant variables.
607615
for snp_col in df_variants.columns:
@@ -673,6 +681,7 @@ def aa_allele_frequencies_advanced(
673681
ci_method: Optional[frq_params.ci_method] = "wilson",
674682
chunks: base_params.chunks = base_params.native_chunks,
675683
inline_array: base_params.inline_array = base_params.inline_array_default,
684+
taxon_by: frq_params.taxon_by = frq_params.taxon_by_default,
676685
) -> xr.Dataset:
677686
# Begin by computing SNP allele frequencies.
678687
ds_snp_frq = self.snp_allele_frequencies_advanced(
@@ -690,6 +699,7 @@ def aa_allele_frequencies_advanced(
690699
ci_method=None, # we will recompute confidence intervals later
691700
chunks=chunks,
692701
inline_array=inline_array,
702+
taxon_by=taxon_by,
693703
)
694704

695705
# N.B., we need to worry about the possibility of the

0 commit comments

Comments
 (0)