Skip to content

Commit 0288ac8

Browse files
committed
Merge branch 'fix-colab-tpu-runtime' of https://github.com/joshitha1808/malariagen-data-python into fix-colab-tpu-runtime
2 parents e361716 + 7bdea7f commit 0288ac8

30 files changed

Lines changed: 1605 additions & 389 deletions

.github/actions/setup-python/action.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ runs:
1919
shell: bash
2020
run: |
2121
poetry env use ${{ inputs.python-version }}
22-
poetry install --extras dev
22+
poetry install --with dev,test,docs

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
__pycache__
44
*.pyc
55
dist
6+
.venv/
67
.coverage
78
coverage.xml
89
.ipynb_checkpoints/

CONTRIBUTING.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ This package provides Python tools for accessing and analyzing genomic data from
1212

1313
You'll need:
1414

15-
- [pipx](https://python-poetry.org/) for installing Python tools
15+
- [pipx](https://pipx.pypa.io/) for installing Python tools
1616
- [git](https://git-scm.com/) for version control
1717

1818
Both of these can be installed using your distribution's package manager or [Homebrew](https://brew.sh/) on Mac.
@@ -52,9 +52,13 @@ Both of these can be installed using your distribution's package manager or [Hom
5252

5353
```bash
5454
poetry env use 3.12
55-
poetry install --extras dev
55+
poetry install --with dev,test,docs
5656
```
5757

58+
This installs the runtime dependencies along with the `dev`, `test`, and `docs`
59+
[dependency groups](https://python-poetry.org/docs/managing-dependencies/#dependency-groups).
60+
If you only need to run tests, `poetry install --with test` is sufficient.
61+
5862
**Recommended**: Use `poetry run` to run commands inside the virtual environment:
5963

6064
```bash

malariagen_data/anoph/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -607,9 +607,9 @@ def _read_sample_sets_manifest(self, *, single_release: str):
607607
# Get today's date in ISO format
608608
today_date_iso = date.today().isoformat()
609609
# Add an "unrestricted_use" column, set to True if the terms-of-use expiry date is missing or <= today's date.
610-
df["unrestricted_use"] = df[terms_of_use_expiry_date_column].apply(
611-
lambda d: True if pd.isna(d) else (d <= today_date_iso)
612-
)
610+
# Vectorized operation: True if NaN, else (d <= today_date_iso)
611+
s = df[terms_of_use_expiry_date_column]
612+
df["unrestricted_use"] = s.isna() | (s <= today_date_iso)
613613
# Make the "unrestricted_use" column a nullable boolean, to allow missing data.
614614
df["unrestricted_use"] = df["unrestricted_use"].astype(pd.BooleanDtype())
615615

malariagen_data/anoph/base_params.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@
6969
str,
7070
"""
7171
A pandas query string to be evaluated against the sample metadata, to
72-
select samples to be included in the returned data.
72+
select samples to be included in the returned data. E.g.,
73+
"country == 'Uganda'". If the query returns zero results, a warning
74+
will be emitted with fuzzy-match suggestions for possible typos or
75+
case mismatches.
7376
""",
7477
]
7578

malariagen_data/anoph/dipclust.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Optional, Tuple
23

34
import allel # type: ignore
@@ -540,8 +541,9 @@ def _insert_dipclust_snp_trace(
540541
figures.append(snp_trace)
541542
subplot_heights.append(snp_row_height * n_snps_transcript)
542543
else:
543-
print(
544-
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
544+
warnings.warn(
545+
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot.",
546+
stacklevel=2,
545547
)
546548
return figures, subplot_heights, n_snps_transcript
547549

@@ -607,8 +609,9 @@ def plot_diplotype_clustering_advanced(
607609
cnv_colorscale = cnv_params.colorscale_default
608610
if cohort_size and snp_transcript:
609611
cohort_size = None
610-
print(
611-
"Cohort size is not supported with amino acid heatmap. Overriding cohort size to None."
612+
warnings.warn(
613+
"Cohort size is not supported with amino acid heatmap. Overriding cohort size to None.",
614+
stacklevel=2,
612615
)
613616

614617
res = self.plot_diplotype_clustering(

malariagen_data/anoph/frq_base.py

Lines changed: 107 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import re
21
from textwrap import dedent
32
from typing import Optional, Union, List
43

@@ -29,6 +28,13 @@ def _prep_samples_for_cohort_grouping(
2928
# Users can explicitly override with True/False.
3029
filter_unassigned = taxon_by == "taxon"
3130

31+
# Validate taxon_by.
32+
if taxon_by not in df_samples.columns:
33+
raise ValueError(
34+
f"Invalid value for `taxon_by`: {taxon_by!r}. "
35+
f"Must be the name of an existing column in the sample metadata."
36+
)
37+
3238
if filter_unassigned:
3339
# Remove samples with "intermediate" or "unassigned" taxon values,
3440
# as we only want cohorts with clean taxon calls.
@@ -43,40 +49,46 @@ def _prep_samples_for_cohort_grouping(
4349

4450
# Add period column.
4551

46-
# Map supported period_by values to functions that return either the relevant pd.Period or pd.NaT per row.
47-
period_by_funcs = {
48-
"year": _make_sample_period_year,
49-
"quarter": _make_sample_period_quarter,
50-
"month": _make_sample_period_month,
52+
# Map supported period_by values to vectorized functions that create Period arrays.
53+
period_by_funcs_vectorized = {
54+
"year": _make_sample_periods_year_vectorized,
55+
"quarter": _make_sample_periods_quarter_vectorized,
56+
"month": _make_sample_periods_month_vectorized,
5157
}
5258

5359
# Get the matching function for the specified period_by value, or None.
54-
period_by_func = period_by_funcs.get(period_by)
60+
period_by_func_vectorized = period_by_funcs_vectorized.get(period_by)
5561

5662
# If there were no matching functions for the specified period_by value...
57-
if period_by_func is None:
63+
if period_by_func_vectorized is None:
5864
# Raise a ValueError if the specified period_by value is not a column in the DataFrame.
5965
if period_by not in df_samples.columns:
6066
raise ValueError(
6167
f"Invalid value for `period_by`: {period_by!r}. Either specify the name of an existing column "
6268
"or a supported period: 'year', 'quarter', or 'month'."
6369
)
6470

65-
# Raise a TypeError if the specified period_by column contains values that are neither pd.Period instances nor null.
66-
if (
67-
not df_samples[period_by]
68-
.apply(lambda value: pd.isnull(value) or isinstance(value, pd.Period))
69-
.all()
70-
):
71-
raise TypeError(
72-
f"Invalid values in {period_by!r} column. Must be either pandas.Period or null."
73-
)
71+
# Validate the specified period_by column contains pandas Periods (or nulls).
72+
s_period_by = df_samples[period_by]
73+
if not pd.api.types.is_period_dtype(s_period_by.dtype):
74+
non_null = s_period_by.dropna()
75+
if len(non_null) > 0 and not non_null.map(type).eq(pd.Period).all():
76+
raise TypeError(
77+
f"Invalid values in {period_by!r} column. Must be either pandas.Period or null."
78+
)
7479

7580
# Copy the specified period_by column to a new "period" column.
7681
df_samples["period"] = df_samples[period_by]
7782
else:
78-
# Apply the matching period_by function to create a new "period" column.
79-
df_samples["period"] = df_samples.apply(period_by_func, axis="columns")
83+
# Use the vectorized period creation function.
84+
df_samples["period"] = period_by_func_vectorized(df_samples)
85+
86+
# Validate area_by.
87+
if area_by not in df_samples.columns:
88+
raise ValueError(
89+
f"Invalid value for `area_by`: {area_by!r}. "
90+
f"Must be the name of an existing column in the sample metadata."
91+
)
8092

8193
# Copy the specified area_by column to a new "area" column.
8294
df_samples["area"] = df_samples[area_by]
@@ -101,22 +113,39 @@ def _build_cohorts_from_sample_grouping(
101113
df_cohorts = df_cohorts.reset_index()
102114

103115
# Add cohort helper variables.
104-
cohort_period_start = df_cohorts["period"].apply(lambda v: v.start_time)
105-
cohort_period_end = df_cohorts["period"].apply(lambda v: v.end_time)
106-
df_cohorts["period_start"] = cohort_period_start
107-
df_cohorts["period_end"] = cohort_period_end
116+
# Vectorized extraction of period start/end times.
117+
period = df_cohorts["period"]
118+
if pd.api.types.is_period_dtype(period.dtype):
119+
df_cohorts["period_start"] = period.dt.start_time
120+
df_cohorts["period_end"] = period.dt.end_time
121+
else:
122+
# Fallback for object dtype Period values.
123+
df_cohorts["period_start"] = period.map(
124+
lambda v: v.start_time if pd.notna(v) else pd.NaT
125+
)
126+
df_cohorts["period_end"] = period.map(
127+
lambda v: v.end_time if pd.notna(v) else pd.NaT
128+
)
129+
108130
# Create a label that is similar to the cohort metadata,
109131
# although this won't be perfect.
132+
# Vectorized string operations
110133
if taxon_by == frq_params.taxon_by_default:
111-
df_cohorts["label"] = df_cohorts.apply(
112-
lambda v: f"{v.area}_{v[taxon_by][:4]}_{v.period}", axis="columns"
113-
)
134+
# Default case: area_taxon_short_period
135+
area_str = df_cohorts["area"].astype(str)
136+
taxon_short = df_cohorts[taxon_by].astype(str).str.slice(0, 4)
137+
period_str = df_cohorts["period"].astype(str)
138+
df_cohorts["label"] = area_str + "_" + taxon_short + "_" + period_str
114139
else:
115-
# Replace non-alphanumeric characters in the taxon with underscores.
116-
df_cohorts["label"] = df_cohorts.apply(
117-
lambda v: f"{v.area}_{re.sub(r'[^A-Za-z0-9]+', '_', str(v[taxon_by]))}_{v.period}",
118-
axis="columns",
140+
# Non-default case: replace non-alphanumeric characters with underscores
141+
area_str = df_cohorts["area"].astype(str)
142+
taxon_clean = (
143+
df_cohorts[taxon_by]
144+
.astype(str)
145+
.str.replace(r"[^A-Za-z0-9]+", "_", regex=True)
119146
)
147+
period_str = df_cohorts["period"].astype(str)
148+
df_cohorts["label"] = area_str + "_" + taxon_clean + "_" + period_str
120149

121150
# Apply minimum cohort size.
122151
df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True)
@@ -173,6 +202,50 @@ def _make_sample_period_year(row):
173202
return pd.NaT
174203

175204

205+
def _make_sample_periods_month_vectorized(df_samples):
206+
year = df_samples["year"]
207+
month = df_samples["month"]
208+
valid = (year > 0) & (month > 0)
209+
210+
out = pd.Series(pd.NaT, index=df_samples.index, dtype="period[M]")
211+
if valid.any():
212+
out.loc[valid] = pd.PeriodIndex.from_fields(
213+
year=year.loc[valid].to_numpy(),
214+
month=month.loc[valid].to_numpy(),
215+
freq="M",
216+
)
217+
return out
218+
219+
220+
def _make_sample_periods_quarter_vectorized(df_samples):
221+
year = df_samples["year"]
222+
month = df_samples["month"]
223+
valid = (year > 0) & (month > 0)
224+
225+
out = pd.Series(pd.NaT, index=df_samples.index, dtype="period[Q-DEC]")
226+
if valid.any():
227+
out.loc[valid] = pd.PeriodIndex.from_fields(
228+
year=year.loc[valid].to_numpy(),
229+
month=month.loc[valid].to_numpy(),
230+
freq="Q-DEC",
231+
)
232+
return out
233+
234+
235+
def _make_sample_periods_year_vectorized(df_samples):
236+
year = df_samples["year"]
237+
valid = year > 0
238+
239+
out = pd.Series(pd.NaT, index=df_samples.index, dtype="period[Y-DEC]")
240+
if valid.any():
241+
out.loc[valid] = pd.PeriodIndex.from_fields(
242+
year=year.loc[valid].to_numpy(),
243+
month=np.full(int(valid.sum()), 12, dtype="int64"),
244+
freq="Y-DEC",
245+
)
246+
return out
247+
248+
176249
class AnophelesFrequencyAnalysis(AnophelesBase):
177250
def __init__(
178251
self,
@@ -263,14 +336,10 @@ def plot_frequencies_heatmap(
263336
index = list(index_names_as_list)
264337
df = df.reset_index().copy()
265338
if isinstance(index, list):
266-
index_col = (
267-
df[index]
268-
.astype(str)
269-
.apply(
270-
lambda row: ", ".join([o for o in row if o is not None]),
271-
axis="columns",
272-
)
273-
)
339+
idx_vals = df[index].astype(str).to_numpy()
340+
index_col = pd.Series(idx_vals[:, 0], index=df.index)
341+
for j in range(1, idx_vals.shape[1]):
342+
index_col = index_col + ", " + idx_vals[:, j]
274343
else:
275344
assert isinstance(index, str)
276345
index_col = df[index].astype(str)

malariagen_data/anoph/genome_features.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -446,29 +446,26 @@ def plot_genes(
446446

447447
# Put gene pointers (▲ or ▼) in a new column, depending on the strand.
448448
# Except if the gene_label is null or an empty string, which should not be shown.
449-
data["gene_pointer"] = data.apply(
450-
lambda row: ("▼" if row["strand"] == "+" else "▲")
451-
if row["gene_label"]
452-
else "",
453-
axis=1,
449+
data["gene_pointer"] = np.where(
450+
data["gene_label"] == "",
451+
"",
452+
np.where(data["strand"] == "+", "▼", "▲"),
454453
)
455454

456455
# Put the pointer above or below the gene rectangle, depending on + or - strand.
457456
neg_strand_pointer_y = orig_mid_y_range - 1.1
458457
pos_strand_pointer_y = orig_mid_y_range + 1.1
459-
data["pointer_y"] = data["strand"].apply(
460-
lambda strand: pos_strand_pointer_y
461-
if strand == "+"
462-
else neg_strand_pointer_y
458+
# Vectorized operation: use np.where instead of Series.apply
459+
data["pointer_y"] = np.where(
460+
data["strand"] == "+", pos_strand_pointer_y, neg_strand_pointer_y
463461
)
464462

465463
# Put the label above or below the gene rectangle, depending on + or - strand.
466464
neg_strand_label_y = orig_mid_y_range - 1.25
467465
pos_strand_label_y = orig_mid_y_range + 1.3
468-
data["label_y"] = data["strand"].apply(
469-
lambda strand: pos_strand_label_y
470-
if strand == "+"
471-
else neg_strand_label_y
466+
# Vectorized operation: use np.where instead of Series.apply
467+
data["label_y"] = np.where(
468+
data["strand"] == "+", pos_strand_label_y, neg_strand_label_y
472469
)
473470

474471
# Get the data as a ColumnDataSource.

0 commit comments

Comments
 (0)