1- import re
21from textwrap import dedent
32from typing import Optional , Union , List
43
@@ -29,6 +28,13 @@ def _prep_samples_for_cohort_grouping(
2928 # Users can explicitly override with True/False.
3029 filter_unassigned = taxon_by == "taxon"
3130
31+ # Validate taxon_by.
32+ if taxon_by not in df_samples .columns :
33+ raise ValueError (
34+ f"Invalid value for `taxon_by`: { taxon_by !r} . "
35+ f"Must be the name of an existing column in the sample metadata."
36+ )
37+
3238 if filter_unassigned :
3339 # Remove samples with "intermediate" or "unassigned" taxon values,
3440 # as we only want cohorts with clean taxon calls.
@@ -43,40 +49,46 @@ def _prep_samples_for_cohort_grouping(
4349
4450 # Add period column.
4551
46- # Map supported period_by values to functions that return either the relevant pd. Period or pd.NaT per row .
47- period_by_funcs = {
48- "year" : _make_sample_period_year ,
49- "quarter" : _make_sample_period_quarter ,
50- "month" : _make_sample_period_month ,
52+ # Map supported period_by values to vectorized functions that create Period arrays .
53+ period_by_funcs_vectorized = {
54+ "year" : _make_sample_periods_year_vectorized ,
55+ "quarter" : _make_sample_periods_quarter_vectorized ,
56+ "month" : _make_sample_periods_month_vectorized ,
5157 }
5258
5359 # Get the matching function for the specified period_by value, or None.
54- period_by_func = period_by_funcs .get (period_by )
60+ period_by_func_vectorized = period_by_funcs_vectorized .get (period_by )
5561
5662 # If there were no matching functions for the specified period_by value...
57- if period_by_func is None :
63+ if period_by_func_vectorized is None :
5864 # Raise a ValueError if the specified period_by value is not a column in the DataFrame.
5965 if period_by not in df_samples .columns :
6066 raise ValueError (
6167 f"Invalid value for `period_by`: { period_by !r} . Either specify the name of an existing column "
6268 "or a supported period: 'year', 'quarter', or 'month'."
6369 )
6470
65- # Raise a ValueError if the specified period_by column does not contain instances pd.Period.
66- if (
67- not df_samples [period_by ]
68- .apply (lambda value : pd .isnull (value ) or isinstance (value , pd .Period ))
69- .all ()
70- ):
71- raise TypeError (
72- f"Invalid values in { period_by !r} column. Must be either pandas.Period or null."
73- )
71+ # Validate the specified period_by column contains pandas Periods (or nulls).
72+ s_period_by = df_samples [period_by ]
73+ if not pd .api .types .is_period_dtype (s_period_by .dtype ):
74+ non_null = s_period_by .dropna ()
75+ if len (non_null ) > 0 and not non_null .map (type ).eq (pd .Period ).all ():
76+ raise TypeError (
77+ f"Invalid values in { period_by !r} column. Must be either pandas.Period or null."
78+ )
7479
7580 # Copy the specified period_by column to a new "period" column.
7681 df_samples ["period" ] = df_samples [period_by ]
7782 else :
78- # Apply the matching period_by function to create a new "period" column.
79- df_samples ["period" ] = df_samples .apply (period_by_func , axis = "columns" )
83+ # Use the vectorized period creation function.
84+ df_samples ["period" ] = period_by_func_vectorized (df_samples )
85+
86+ # Validate area_by.
87+ if area_by not in df_samples .columns :
88+ raise ValueError (
89+ f"Invalid value for `area_by`: { area_by !r} . "
90+ f"Must be the name of an existing column in the sample metadata."
91+ )
8092
8193 # Copy the specified area_by column to a new "area" column.
8294 df_samples ["area" ] = df_samples [area_by ]
@@ -101,22 +113,39 @@ def _build_cohorts_from_sample_grouping(
101113 df_cohorts = df_cohorts .reset_index ()
102114
103115 # Add cohort helper variables.
104- cohort_period_start = df_cohorts ["period" ].apply (lambda v : v .start_time )
105- cohort_period_end = df_cohorts ["period" ].apply (lambda v : v .end_time )
106- df_cohorts ["period_start" ] = cohort_period_start
107- df_cohorts ["period_end" ] = cohort_period_end
116+ # Vectorized extraction of period start/end times.
117+ period = df_cohorts ["period" ]
118+ if pd .api .types .is_period_dtype (period .dtype ):
119+ df_cohorts ["period_start" ] = period .dt .start_time
120+ df_cohorts ["period_end" ] = period .dt .end_time
121+ else :
122+ # Fallback for object dtype Period values.
123+ df_cohorts ["period_start" ] = period .map (
124+ lambda v : v .start_time if pd .notna (v ) else pd .NaT
125+ )
126+ df_cohorts ["period_end" ] = period .map (
127+ lambda v : v .end_time if pd .notna (v ) else pd .NaT
128+ )
129+
108130 # Create a label that is similar to the cohort metadata,
109131 # although this won't be perfect.
132+ # Vectorized string operations
110133 if taxon_by == frq_params .taxon_by_default :
111- df_cohorts ["label" ] = df_cohorts .apply (
112- lambda v : f"{ v .area } _{ v [taxon_by ][:4 ]} _{ v .period } " , axis = "columns"
113- )
134+ # Default case: area_taxon_short_period
135+ area_str = df_cohorts ["area" ].astype (str )
136+ taxon_short = df_cohorts [taxon_by ].astype (str ).str .slice (0 , 4 )
137+ period_str = df_cohorts ["period" ].astype (str )
138+ df_cohorts ["label" ] = area_str + "_" + taxon_short + "_" + period_str
114139 else :
115- # Replace non-alphanumeric characters in the taxon with underscores.
116- df_cohorts ["label" ] = df_cohorts .apply (
117- lambda v : f"{ v .area } _{ re .sub (r'[^A-Za-z0-9]+' , '_' , str (v [taxon_by ]))} _{ v .period } " ,
118- axis = "columns" ,
140+ # Non-default case: replace non-alphanumeric characters with underscores
141+ area_str = df_cohorts ["area" ].astype (str )
142+ taxon_clean = (
143+ df_cohorts [taxon_by ]
144+ .astype (str )
145+ .str .replace (r"[^A-Za-z0-9]+" , "_" , regex = True )
119146 )
147+ period_str = df_cohorts ["period" ].astype (str )
148+ df_cohorts ["label" ] = area_str + "_" + taxon_clean + "_" + period_str
120149
121150 # Apply minimum cohort size.
122151 df_cohorts = df_cohorts .query (f"size >= { min_cohort_size } " ).reset_index (drop = True )
@@ -173,6 +202,50 @@ def _make_sample_period_year(row):
173202 return pd .NaT
174203
175204
205+ def _make_sample_periods_month_vectorized (df_samples ):
206+ year = df_samples ["year" ]
207+ month = df_samples ["month" ]
208+ valid = (year > 0 ) & (month > 0 )
209+
210+ out = pd .Series (pd .NaT , index = df_samples .index , dtype = "period[M]" )
211+ if valid .any ():
212+ out .loc [valid ] = pd .PeriodIndex .from_fields (
213+ year = year .loc [valid ].to_numpy (),
214+ month = month .loc [valid ].to_numpy (),
215+ freq = "M" ,
216+ )
217+ return out
218+
219+
220+ def _make_sample_periods_quarter_vectorized (df_samples ):
221+ year = df_samples ["year" ]
222+ month = df_samples ["month" ]
223+ valid = (year > 0 ) & (month > 0 )
224+
225+ out = pd .Series (pd .NaT , index = df_samples .index , dtype = "period[Q-DEC]" )
226+ if valid .any ():
227+ out .loc [valid ] = pd .PeriodIndex .from_fields (
228+ year = year .loc [valid ].to_numpy (),
229+ month = month .loc [valid ].to_numpy (),
230+ freq = "Q-DEC" ,
231+ )
232+ return out
233+
234+
235+ def _make_sample_periods_year_vectorized (df_samples ):
236+ year = df_samples ["year" ]
237+ valid = year > 0
238+
239+ out = pd .Series (pd .NaT , index = df_samples .index , dtype = "period[Y-DEC]" )
240+ if valid .any ():
241+ out .loc [valid ] = pd .PeriodIndex .from_fields (
242+ year = year .loc [valid ].to_numpy (),
243+ month = np .full (int (valid .sum ()), 12 , dtype = "int64" ),
244+ freq = "Y-DEC" ,
245+ )
246+ return out
247+
248+
176249class AnophelesFrequencyAnalysis (AnophelesBase ):
177250 def __init__ (
178251 self ,
@@ -263,14 +336,10 @@ def plot_frequencies_heatmap(
263336 index = list (index_names_as_list )
264337 df = df .reset_index ().copy ()
265338 if isinstance (index , list ):
266- index_col = (
267- df [index ]
268- .astype (str )
269- .apply (
270- lambda row : ", " .join ([o for o in row if o is not None ]),
271- axis = "columns" ,
272- )
273- )
339+ idx_vals = df [index ].astype (str ).to_numpy ()
340+ index_col = pd .Series (idx_vals [:, 0 ], index = df .index )
341+ for j in range (1 , idx_vals .shape [1 ]):
342+ index_col = index_col + ", " + idx_vals [:, j ]
274343 else :
275344 assert isinstance (index , str )
276345 index_col = df [index ].astype (str )
0 commit comments