|
1 | | -import re |
2 | 1 | from textwrap import dedent |
3 | 2 | from typing import Optional, Union, List |
4 | 3 |
|
@@ -50,40 +49,39 @@ def _prep_samples_for_cohort_grouping( |
50 | 49 |
|
51 | 50 | # Add period column. |
52 | 51 |
|
53 | | - # Map supported period_by values to functions that return either the relevant pd.Period or pd.NaT per row. |
54 | | - period_by_funcs = { |
55 | | - "year": _make_sample_period_year, |
56 | | - "quarter": _make_sample_period_quarter, |
57 | | - "month": _make_sample_period_month, |
| 52 | + # Map supported period_by values to vectorized functions that create Period arrays. |
| 53 | + period_by_funcs_vectorized = { |
| 54 | + "year": _make_sample_periods_year_vectorized, |
| 55 | + "quarter": _make_sample_periods_quarter_vectorized, |
| 56 | + "month": _make_sample_periods_month_vectorized, |
58 | 57 | } |
59 | 58 |
|
60 | 59 | # Get the matching function for the specified period_by value, or None. |
61 | | - period_by_func = period_by_funcs.get(period_by) |
| 60 | + period_by_func_vectorized = period_by_funcs_vectorized.get(period_by) |
62 | 61 |
|
63 | 62 | # If there were no matching functions for the specified period_by value... |
64 | | - if period_by_func is None: |
| 63 | + if period_by_func_vectorized is None: |
65 | 64 | # Raise a ValueError if the specified period_by value is not a column in the DataFrame. |
66 | 65 | if period_by not in df_samples.columns: |
67 | 66 | raise ValueError( |
68 | 67 | f"Invalid value for `period_by`: {period_by!r}. Either specify the name of an existing column " |
69 | 68 | "or a supported period: 'year', 'quarter', or 'month'." |
70 | 69 | ) |
71 | 70 |
|
72 | | - # Raise a ValueError if the specified period_by column does not contain instances pd.Period. |
73 | | - if ( |
74 | | - not df_samples[period_by] |
75 | | - .apply(lambda value: pd.isnull(value) or isinstance(value, pd.Period)) |
76 | | - .all() |
77 | | - ): |
78 | | - raise TypeError( |
79 | | - f"Invalid values in {period_by!r} column. Must be either pandas.Period or null." |
80 | | - ) |
| 71 | + # Validate the specified period_by column contains pandas Periods (or nulls). |
| 72 | + s_period_by = df_samples[period_by] |
| 73 | + if not pd.api.types.is_period_dtype(s_period_by.dtype): |
| 74 | + non_null = s_period_by.dropna() |
| 75 | + if len(non_null) > 0 and not non_null.map(type).eq(pd.Period).all(): |
| 76 | + raise TypeError( |
| 77 | + f"Invalid values in {period_by!r} column. Must be either pandas.Period or null." |
| 78 | + ) |
81 | 79 |
|
82 | 80 | # Copy the specified period_by column to a new "period" column. |
83 | 81 | df_samples["period"] = df_samples[period_by] |
84 | 82 | else: |
85 | | - # Apply the matching period_by function to create a new "period" column. |
86 | | - df_samples["period"] = df_samples.apply(period_by_func, axis="columns") |
| 83 | + # Use the vectorized period creation function. |
| 84 | + df_samples["period"] = period_by_func_vectorized(df_samples) |
87 | 85 |
|
88 | 86 | # Validate area_by. |
89 | 87 | if area_by not in df_samples.columns: |
@@ -115,22 +113,39 @@ def _build_cohorts_from_sample_grouping( |
115 | 113 | df_cohorts = df_cohorts.reset_index() |
116 | 114 |
|
117 | 115 | # Add cohort helper variables. |
118 | | - cohort_period_start = df_cohorts["period"].apply(lambda v: v.start_time) |
119 | | - cohort_period_end = df_cohorts["period"].apply(lambda v: v.end_time) |
120 | | - df_cohorts["period_start"] = cohort_period_start |
121 | | - df_cohorts["period_end"] = cohort_period_end |
| 116 | + # Vectorized extraction of period start/end times. |
| 117 | + period = df_cohorts["period"] |
| 118 | + if pd.api.types.is_period_dtype(period.dtype): |
| 119 | + df_cohorts["period_start"] = period.dt.start_time |
| 120 | + df_cohorts["period_end"] = period.dt.end_time |
| 121 | + else: |
| 122 | + # Fallback for object dtype Period values. |
| 123 | + df_cohorts["period_start"] = period.map( |
| 124 | + lambda v: v.start_time if pd.notna(v) else pd.NaT |
| 125 | + ) |
| 126 | + df_cohorts["period_end"] = period.map( |
| 127 | + lambda v: v.end_time if pd.notna(v) else pd.NaT |
| 128 | + ) |
| 129 | + |
122 | 130 | # Create a label that is similar to the cohort metadata, |
123 | 131 | # although this won't be perfect. |
| 132 | + # Vectorized string operations |
124 | 133 | if taxon_by == frq_params.taxon_by_default: |
125 | | - df_cohorts["label"] = df_cohorts.apply( |
126 | | - lambda v: f"{v.area}_{v[taxon_by][:4]}_{v.period}", axis="columns" |
127 | | - ) |
| 134 | + # Default case: area_taxon_short_period |
| 135 | + area_str = df_cohorts["area"].astype(str) |
| 136 | + taxon_short = df_cohorts[taxon_by].astype(str).str.slice(0, 4) |
| 137 | + period_str = df_cohorts["period"].astype(str) |
| 138 | + df_cohorts["label"] = area_str + "_" + taxon_short + "_" + period_str |
128 | 139 | else: |
129 | | - # Replace non-alphanumeric characters in the taxon with underscores. |
130 | | - df_cohorts["label"] = df_cohorts.apply( |
131 | | - lambda v: f"{v.area}_{re.sub(r'[^A-Za-z0-9]+', '_', str(v[taxon_by]))}_{v.period}", |
132 | | - axis="columns", |
| 140 | + # Non-default case: replace non-alphanumeric characters with underscores |
| 141 | + area_str = df_cohorts["area"].astype(str) |
| 142 | + taxon_clean = ( |
| 143 | + df_cohorts[taxon_by] |
| 144 | + .astype(str) |
| 145 | + .str.replace(r"[^A-Za-z0-9]+", "_", regex=True) |
133 | 146 | ) |
| 147 | + period_str = df_cohorts["period"].astype(str) |
| 148 | + df_cohorts["label"] = area_str + "_" + taxon_clean + "_" + period_str |
134 | 149 |
|
135 | 150 | # Apply minimum cohort size. |
136 | 151 | df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True) |
@@ -187,6 +202,50 @@ def _make_sample_period_year(row): |
187 | 202 | return pd.NaT |
188 | 203 |
|
189 | 204 |
|
| 205 | +def _make_sample_periods_month_vectorized(df_samples): |
| 206 | + year = df_samples["year"] |
| 207 | + month = df_samples["month"] |
| 208 | + valid = (year > 0) & (month > 0) |
| 209 | + |
| 210 | + out = pd.Series(pd.NaT, index=df_samples.index, dtype="period[M]") |
| 211 | + if valid.any(): |
| 212 | + out.loc[valid] = pd.PeriodIndex.from_fields( |
| 213 | + year=year.loc[valid].to_numpy(), |
| 214 | + month=month.loc[valid].to_numpy(), |
| 215 | + freq="M", |
| 216 | + ) |
| 217 | + return out |
| 218 | + |
| 219 | + |
| 220 | +def _make_sample_periods_quarter_vectorized(df_samples): |
| 221 | + year = df_samples["year"] |
| 222 | + month = df_samples["month"] |
| 223 | + valid = (year > 0) & (month > 0) |
| 224 | + |
| 225 | + out = pd.Series(pd.NaT, index=df_samples.index, dtype="period[Q-DEC]") |
| 226 | + if valid.any(): |
| 227 | + out.loc[valid] = pd.PeriodIndex.from_fields( |
| 228 | + year=year.loc[valid].to_numpy(), |
| 229 | + month=month.loc[valid].to_numpy(), |
| 230 | + freq="Q-DEC", |
| 231 | + ) |
| 232 | + return out |
| 233 | + |
| 234 | + |
| 235 | +def _make_sample_periods_year_vectorized(df_samples): |
| 236 | + year = df_samples["year"] |
| 237 | + valid = year > 0 |
| 238 | + |
| 239 | + out = pd.Series(pd.NaT, index=df_samples.index, dtype="period[Y-DEC]") |
| 240 | + if valid.any(): |
| 241 | + out.loc[valid] = pd.PeriodIndex.from_fields( |
| 242 | + year=year.loc[valid].to_numpy(), |
| 243 | + month=np.full(int(valid.sum()), 12, dtype="int64"), |
| 244 | + freq="Y-DEC", |
| 245 | + ) |
| 246 | + return out |
| 247 | + |
| 248 | + |
190 | 249 | class AnophelesFrequencyAnalysis(AnophelesBase): |
191 | 250 | def __init__( |
192 | 251 | self, |
@@ -277,14 +336,10 @@ def plot_frequencies_heatmap( |
277 | 336 | index = list(index_names_as_list) |
278 | 337 | df = df.reset_index().copy() |
279 | 338 | if isinstance(index, list): |
280 | | - index_col = ( |
281 | | - df[index] |
282 | | - .astype(str) |
283 | | - .apply( |
284 | | - lambda row: ", ".join([o for o in row if o is not None]), |
285 | | - axis="columns", |
286 | | - ) |
287 | | - ) |
| 339 | + idx_vals = df[index].astype(str).to_numpy() |
| 340 | + index_col = pd.Series(idx_vals[:, 0], index=df.index) |
| 341 | + for j in range(1, idx_vals.shape[1]): |
| 342 | + index_col = index_col + ", " + idx_vals[:, j] |
288 | 343 | else: |
289 | 344 | assert isinstance(index, str) |
290 | 345 | index_col = df[index].astype(str) |
|
0 commit comments