Skip to content

Commit 5d0e8b8

Browse files
committed
perf: vectorize metadata DataFrame operations
Replaces row-wise pandas apply() usage in metadata/cohort preparation with vectorized numpy/pandas operations to reduce Python-level iteration. Made-with: Cursor
1 parent c269768 commit 5d0e8b8

5 files changed

Lines changed: 107 additions & 58 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
__pycache__
44
*.pyc
55
dist
6+
.venv/
67
.coverage
78
coverage.xml
89
.ipynb_checkpoints/

malariagen_data/anoph/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -607,9 +607,9 @@ def _read_sample_sets_manifest(self, *, single_release: str):
607607
# Get today's date in ISO format
608608
today_date_iso = date.today().isoformat()
609609
# Add an "unrestricted_use" column, set to True if terms-of-use expiry date <= today's date.
610-
df["unrestricted_use"] = df[terms_of_use_expiry_date_column].apply(
611-
lambda d: True if pd.isna(d) else (d <= today_date_iso)
612-
)
610+
# Vectorized operation: True if NaN, else (d <= today_date_iso)
611+
s = df[terms_of_use_expiry_date_column]
612+
df["unrestricted_use"] = (s.isna() | (s <= today_date_iso))
613613
# Make the "unrestricted_use" column a nullable boolean, to allow missing data.
614614
df["unrestricted_use"] = df["unrestricted_use"].astype(pd.BooleanDtype())
615615

malariagen_data/anoph/frq_base.py

Lines changed: 90 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -43,40 +43,39 @@ def _prep_samples_for_cohort_grouping(
4343

4444
# Add period column.
4545

46-
# Map supported period_by values to functions that return either the relevant pd.Period or pd.NaT per row.
47-
period_by_funcs = {
48-
"year": _make_sample_period_year,
49-
"quarter": _make_sample_period_quarter,
50-
"month": _make_sample_period_month,
46+
# Map supported period_by values to vectorized functions that create Period arrays.
47+
period_by_funcs_vectorized = {
48+
"year": _make_sample_periods_year_vectorized,
49+
"quarter": _make_sample_periods_quarter_vectorized,
50+
"month": _make_sample_periods_month_vectorized,
5151
}
5252

5353
# Get the matching function for the specified period_by value, or None.
54-
period_by_func = period_by_funcs.get(period_by)
54+
period_by_func_vectorized = period_by_funcs_vectorized.get(period_by)
5555

5656
# If there were no matching functions for the specified period_by value...
57-
if period_by_func is None:
57+
if period_by_func_vectorized is None:
5858
# Raise a ValueError if the specified period_by value is not a column in the DataFrame.
5959
if period_by not in df_samples.columns:
6060
raise ValueError(
6161
f"Invalid value for `period_by`: {period_by!r}. Either specify the name of an existing column "
6262
"or a supported period: 'year', 'quarter', or 'month'."
6363
)
6464

65-
# Raise a ValueError if the specified period_by column does not contain instances pd.Period.
66-
if (
67-
not df_samples[period_by]
68-
.apply(lambda value: pd.isnull(value) or isinstance(value, pd.Period))
69-
.all()
70-
):
71-
raise TypeError(
72-
f"Invalid values in {period_by!r} column. Must be either pandas.Period or null."
73-
)
65+
# Validate the specified period_by column contains pandas Periods (or nulls).
66+
s_period_by = df_samples[period_by]
67+
if not pd.api.types.is_period_dtype(s_period_by.dtype):
68+
non_null = s_period_by.dropna()
69+
if len(non_null) > 0 and not non_null.map(type).eq(pd.Period).all():
70+
raise TypeError(
71+
f"Invalid values in {period_by!r} column. Must be either pandas.Period or null."
72+
)
7473

7574
# Copy the specified period_by column to a new "period" column.
7675
df_samples["period"] = df_samples[period_by]
7776
else:
78-
# Apply the matching period_by function to create a new "period" column.
79-
df_samples["period"] = df_samples.apply(period_by_func, axis="columns")
77+
# Use the vectorized period creation function.
78+
df_samples["period"] = period_by_func_vectorized(df_samples)
8079

8180
# Copy the specified area_by column to a new "area" column.
8281
df_samples["area"] = df_samples[area_by]
@@ -101,22 +100,35 @@ def _build_cohorts_from_sample_grouping(
101100
df_cohorts = df_cohorts.reset_index()
102101

103102
# Add cohort helper variables.
104-
cohort_period_start = df_cohorts["period"].apply(lambda v: v.start_time)
105-
cohort_period_end = df_cohorts["period"].apply(lambda v: v.end_time)
106-
df_cohorts["period_start"] = cohort_period_start
107-
df_cohorts["period_end"] = cohort_period_end
103+
# Vectorized extraction of period start/end times.
104+
period = df_cohorts["period"]
105+
if pd.api.types.is_period_dtype(period.dtype):
106+
df_cohorts["period_start"] = period.dt.start_time
107+
df_cohorts["period_end"] = period.dt.end_time
108+
else:
109+
# Fallback for object dtype Period values.
110+
df_cohorts["period_start"] = period.map(
111+
lambda v: v.start_time if pd.notna(v) else pd.NaT
112+
)
113+
df_cohorts["period_end"] = period.map(
114+
lambda v: v.end_time if pd.notna(v) else pd.NaT
115+
)
116+
108117
# Create a label that is similar to the cohort metadata,
109118
# although this won't be perfect.
119+
# Vectorized string operations
110120
if taxon_by == frq_params.taxon_by_default:
111-
df_cohorts["label"] = df_cohorts.apply(
112-
lambda v: f"{v.area}_{v[taxon_by][:4]}_{v.period}", axis="columns"
113-
)
121+
# Default case: area_taxon_short_period
122+
area_str = df_cohorts["area"].astype(str)
123+
taxon_short = df_cohorts[taxon_by].astype(str).str.slice(0, 4)
124+
period_str = df_cohorts["period"].astype(str)
125+
df_cohorts["label"] = area_str + "_" + taxon_short + "_" + period_str
114126
else:
115-
# Replace non-alphanumeric characters in the taxon with underscores.
116-
df_cohorts["label"] = df_cohorts.apply(
117-
lambda v: f"{v.area}_{re.sub(r'[^A-Za-z0-9]+', '_', str(v[taxon_by]))}_{v.period}",
118-
axis="columns",
119-
)
127+
# Non-default case: replace non-alphanumeric characters with underscores
128+
area_str = df_cohorts["area"].astype(str)
129+
taxon_clean = df_cohorts[taxon_by].astype(str).str.replace(r"[^A-Za-z0-9]+", "_", regex=True)
130+
period_str = df_cohorts["period"].astype(str)
131+
df_cohorts["label"] = area_str + "_" + taxon_clean + "_" + period_str
120132

121133
# Apply minimum cohort size.
122134
df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True)
@@ -173,6 +185,50 @@ def _make_sample_period_year(row):
173185
return pd.NaT
174186

175187

188+
def _make_sample_periods_month_vectorized(df_samples):
189+
year = df_samples["year"]
190+
month = df_samples["month"]
191+
valid = (year > 0) & (month > 0)
192+
193+
out = pd.Series(pd.NaT, index=df_samples.index, dtype="period[M]")
194+
if valid.any():
195+
out.loc[valid] = pd.PeriodIndex.from_fields(
196+
year=year.loc[valid].to_numpy(),
197+
month=month.loc[valid].to_numpy(),
198+
freq="M",
199+
)
200+
return out
201+
202+
203+
def _make_sample_periods_quarter_vectorized(df_samples):
204+
year = df_samples["year"]
205+
month = df_samples["month"]
206+
valid = (year > 0) & (month > 0)
207+
208+
out = pd.Series(pd.NaT, index=df_samples.index, dtype="period[Q-DEC]")
209+
if valid.any():
210+
out.loc[valid] = pd.PeriodIndex.from_fields(
211+
year=year.loc[valid].to_numpy(),
212+
month=month.loc[valid].to_numpy(),
213+
freq="Q-DEC",
214+
)
215+
return out
216+
217+
218+
def _make_sample_periods_year_vectorized(df_samples):
219+
year = df_samples["year"]
220+
valid = year > 0
221+
222+
out = pd.Series(pd.NaT, index=df_samples.index, dtype="period[Y-DEC]")
223+
if valid.any():
224+
out.loc[valid] = pd.PeriodIndex.from_fields(
225+
year=year.loc[valid].to_numpy(),
226+
month=np.full(int(valid.sum()), 12, dtype="int64"),
227+
freq="Y-DEC",
228+
)
229+
return out
230+
231+
176232
class AnophelesFrequencyAnalysis(AnophelesBase):
177233
def __init__(
178234
self,
@@ -263,14 +319,10 @@ def plot_frequencies_heatmap(
263319
index = list(index_names_as_list)
264320
df = df.reset_index().copy()
265321
if isinstance(index, list):
266-
index_col = (
267-
df[index]
268-
.astype(str)
269-
.apply(
270-
lambda row: ", ".join([o for o in row if o is not None]),
271-
axis="columns",
272-
)
273-
)
322+
idx_vals = df[index].astype(str).to_numpy()
323+
index_col = pd.Series(idx_vals[:, 0], index=df.index)
324+
for j in range(1, idx_vals.shape[1]):
325+
index_col = index_col + ", " + idx_vals[:, j]
274326
else:
275327
assert isinstance(index, str)
276328
index_col = df[index].astype(str)

malariagen_data/anoph/genome_features.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -446,29 +446,25 @@ def plot_genes(
446446

447447
# Put gene pointers (▲ or ▼) in a new column, depending on the strand.
448448
# Except if the gene_label is null or an empty string, which should not be shown.
449-
data["gene_pointer"] = data.apply(
450-
lambda row: ("▼" if row["strand"] == "+" else "▲")
451-
if row["gene_label"]
452-
else "",
453-
axis=1,
454-
)
449+
# Vectorized operation: use np.where for conditional logic
450+
has_label = data["gene_label"].astype(bool)
451+
pointer_symbol = np.where(data["strand"] == "+", "▼", "▲")
452+
data["gene_pointer"] = np.where(has_label, pointer_symbol, "")
455453

456454
# Put the pointer above or below the gene rectangle, depending on + or - strand.
457455
neg_strand_pointer_y = orig_mid_y_range - 1.1
458456
pos_strand_pointer_y = orig_mid_y_range + 1.1
459-
data["pointer_y"] = data["strand"].apply(
460-
lambda strand: pos_strand_pointer_y
461-
if strand == "+"
462-
else neg_strand_pointer_y
457+
# Vectorized operation: use np.where instead of Series.apply
458+
data["pointer_y"] = np.where(
459+
data["strand"] == "+", pos_strand_pointer_y, neg_strand_pointer_y
463460
)
464461

465462
# Put the label above or below the gene rectangle, depending on + or - strand.
466463
neg_strand_label_y = orig_mid_y_range - 1.25
467464
pos_strand_label_y = orig_mid_y_range + 1.3
468-
data["label_y"] = data["strand"].apply(
469-
lambda strand: pos_strand_label_y
470-
if strand == "+"
471-
else neg_strand_label_y
465+
# Vectorized operation: use np.where instead of Series.apply
466+
data["label_y"] = np.where(
467+
data["strand"] == "+", pos_strand_label_y, neg_strand_label_y
472468
)
473469

474470
# Get the data as a ColumnDataSource.

malariagen_data/anoph/sample_metadata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ def _parse_general_metadata(
181181
df["release"] = release
182182

183183
# Derive a quarter column from month.
184-
df["quarter"] = df.apply(
185-
lambda row: ((row.month - 1) // 3) + 1 if row.month > 0 else -1,
186-
axis="columns",
184+
# Vectorized operation: quarter = ((month - 1) // 3) + 1 if month > 0 else -1
185+
df["quarter"] = np.where(
186+
df["month"] > 0, ((df["month"] - 1) // 3) + 1, -1
187187
)
188188

189189
# Add study columns.

0 commit comments

Comments
 (0)