@@ -43,40 +43,39 @@ def _prep_samples_for_cohort_grouping(
4343
4444 # Add period column.
4545
46- # Map supported period_by values to functions that return either the relevant pd. Period or pd.NaT per row .
47- period_by_funcs = {
48- "year" : _make_sample_period_year ,
49- "quarter" : _make_sample_period_quarter ,
50- "month" : _make_sample_period_month ,
46+ # Map supported period_by values to vectorized functions that create Period arrays .
47+ period_by_funcs_vectorized = {
48+ "year" : _make_sample_periods_year_vectorized ,
49+ "quarter" : _make_sample_periods_quarter_vectorized ,
50+ "month" : _make_sample_periods_month_vectorized ,
5151 }
5252
5353 # Get the matching function for the specified period_by value, or None.
54- period_by_func = period_by_funcs .get (period_by )
54+ period_by_func_vectorized = period_by_funcs_vectorized .get (period_by )
5555
5656 # If there were no matching functions for the specified period_by value...
57- if period_by_func is None :
57+ if period_by_func_vectorized is None :
5858 # Raise a ValueError if the specified period_by value is not a column in the DataFrame.
5959 if period_by not in df_samples .columns :
6060 raise ValueError (
6161 f"Invalid value for `period_by`: { period_by !r} . Either specify the name of an existing column "
6262 "or a supported period: 'year', 'quarter', or 'month'."
6363 )
6464
65- # Raise a ValueError if the specified period_by column does not contain instances pd.Period.
66- if (
67- not df_samples [period_by ]
68- .apply (lambda value : pd .isnull (value ) or isinstance (value , pd .Period ))
69- .all ()
70- ):
71- raise TypeError (
72- f"Invalid values in { period_by !r} column. Must be either pandas.Period or null."
73- )
65+ # Validate the specified period_by column contains pandas Periods (or nulls).
66+ s_period_by = df_samples [period_by ]
67+ if not pd .api .types .is_period_dtype (s_period_by .dtype ):
68+ non_null = s_period_by .dropna ()
69+ if len (non_null ) > 0 and not non_null .map (type ).eq (pd .Period ).all ():
70+ raise TypeError (
71+ f"Invalid values in { period_by !r} column. Must be either pandas.Period or null."
72+ )
7473
7574 # Copy the specified period_by column to a new "period" column.
7675 df_samples ["period" ] = df_samples [period_by ]
7776 else :
78- # Apply the matching period_by function to create a new " period" column .
79- df_samples ["period" ] = df_samples . apply ( period_by_func , axis = "columns" )
77+ # Use the vectorized period creation function .
78+ df_samples ["period" ] = period_by_func_vectorized ( df_samples )
8079
8180 # Copy the specified area_by column to a new "area" column.
8281 df_samples ["area" ] = df_samples [area_by ]
@@ -101,22 +100,35 @@ def _build_cohorts_from_sample_grouping(
101100 df_cohorts = df_cohorts .reset_index ()
102101
103102 # Add cohort helper variables.
104- cohort_period_start = df_cohorts ["period" ].apply (lambda v : v .start_time )
105- cohort_period_end = df_cohorts ["period" ].apply (lambda v : v .end_time )
106- df_cohorts ["period_start" ] = cohort_period_start
107- df_cohorts ["period_end" ] = cohort_period_end
103+ # Vectorized extraction of period start/end times.
104+ period = df_cohorts ["period" ]
105+ if pd .api .types .is_period_dtype (period .dtype ):
106+ df_cohorts ["period_start" ] = period .dt .start_time
107+ df_cohorts ["period_end" ] = period .dt .end_time
108+ else :
109+ # Fallback for object dtype Period values.
110+ df_cohorts ["period_start" ] = period .map (
111+ lambda v : v .start_time if pd .notna (v ) else pd .NaT
112+ )
113+ df_cohorts ["period_end" ] = period .map (
114+ lambda v : v .end_time if pd .notna (v ) else pd .NaT
115+ )
116+
108117 # Create a label that is similar to the cohort metadata,
109118 # although this won't be perfect.
119+ # Vectorized string operations
110120 if taxon_by == frq_params .taxon_by_default :
111- df_cohorts ["label" ] = df_cohorts .apply (
112- lambda v : f"{ v .area } _{ v [taxon_by ][:4 ]} _{ v .period } " , axis = "columns"
113- )
121+ # Default case: area_taxon_short_period
122+ area_str = df_cohorts ["area" ].astype (str )
123+ taxon_short = df_cohorts [taxon_by ].astype (str ).str .slice (0 , 4 )
124+ period_str = df_cohorts ["period" ].astype (str )
125+ df_cohorts ["label" ] = area_str + "_" + taxon_short + "_" + period_str
114126 else :
115- # Replace non-alphanumeric characters in the taxon with underscores.
116- df_cohorts ["label" ] = df_cohorts . apply (
117- lambda v : f" { v . area } _ { re . sub ( r' [^A-Za-z0-9]+' , '_' , str ( v [ taxon_by ])) } _ { v . period } " ,
118- axis = "columns" ,
119- )
127+ # Non-default case: replace non-alphanumeric characters with underscores
128+ area_str = df_cohorts ["area" ]. astype ( str )
129+ taxon_clean = df_cohorts [ taxon_by ]. astype ( str ). str . replace ( r" [^A-Za-z0-9]+" , "_" , regex = True )
130+ period_str = df_cohorts [ "period" ]. astype ( str )
131+ df_cohorts [ "label" ] = area_str + "_" + taxon_clean + "_" + period_str
120132
121133 # Apply minimum cohort size.
122134 df_cohorts = df_cohorts .query (f"size >= { min_cohort_size } " ).reset_index (drop = True )
@@ -173,6 +185,50 @@ def _make_sample_period_year(row):
173185 return pd .NaT
174186
175187
188+ def _make_sample_periods_month_vectorized (df_samples ):
189+ year = df_samples ["year" ]
190+ month = df_samples ["month" ]
191+ valid = (year > 0 ) & (month > 0 )
192+
193+ out = pd .Series (pd .NaT , index = df_samples .index , dtype = "period[M]" )
194+ if valid .any ():
195+ out .loc [valid ] = pd .PeriodIndex .from_fields (
196+ year = year .loc [valid ].to_numpy (),
197+ month = month .loc [valid ].to_numpy (),
198+ freq = "M" ,
199+ )
200+ return out
201+
202+
203+ def _make_sample_periods_quarter_vectorized (df_samples ):
204+ year = df_samples ["year" ]
205+ month = df_samples ["month" ]
206+ valid = (year > 0 ) & (month > 0 )
207+
208+ out = pd .Series (pd .NaT , index = df_samples .index , dtype = "period[Q-DEC]" )
209+ if valid .any ():
210+ out .loc [valid ] = pd .PeriodIndex .from_fields (
211+ year = year .loc [valid ].to_numpy (),
212+ month = month .loc [valid ].to_numpy (),
213+ freq = "Q-DEC" ,
214+ )
215+ return out
216+
217+
218+ def _make_sample_periods_year_vectorized (df_samples ):
219+ year = df_samples ["year" ]
220+ valid = year > 0
221+
222+ out = pd .Series (pd .NaT , index = df_samples .index , dtype = "period[Y-DEC]" )
223+ if valid .any ():
224+ out .loc [valid ] = pd .PeriodIndex .from_fields (
225+ year = year .loc [valid ].to_numpy (),
226+ month = np .full (int (valid .sum ()), 12 , dtype = "int64" ),
227+ freq = "Y-DEC" ,
228+ )
229+ return out
230+
231+
176232class AnophelesFrequencyAnalysis (AnophelesBase ):
177233 def __init__ (
178234 self ,
@@ -263,14 +319,10 @@ def plot_frequencies_heatmap(
263319 index = list (index_names_as_list )
264320 df = df .reset_index ().copy ()
265321 if isinstance (index , list ):
266- index_col = (
267- df [index ]
268- .astype (str )
269- .apply (
270- lambda row : ", " .join ([o for o in row if o is not None ]),
271- axis = "columns" ,
272- )
273- )
322+ idx_vals = df [index ].astype (str ).to_numpy ()
323+ index_col = pd .Series (idx_vals [:, 0 ], index = df .index )
324+ for j in range (1 , idx_vals .shape [1 ]):
325+ index_col = index_col + ", " + idx_vals [:, j ]
274326 else :
275327 assert isinstance (index , str )
276328 index_col = df [index ].astype (str )
0 commit comments