@@ -246,23 +246,37 @@ def _pca(
246246 if max_missing_an is not None and max_missing_an > 0 :
247247 gn_fit = gn_fit .astype (float )
248248 gn = gn .astype (float )
249+
249250 for arr in [gn_fit , gn ]:
250251 missing_mask = arr == - 127
251252
252253 if imputation_method == "most_common" :
253- # For each site, find the most common non-missing value.
254- site_modes = []
255- for row in arr :
256- non_missing = row [row != - 127 ]
257- if len (non_missing ) == 0 :
258- site_modes .append (0 )
259- else :
260- values , counts = np .unique (
261- non_missing , return_counts = True
262- )
263- site_modes .append (values [np .argmax (counts )])
264- site_modes = np .array (site_modes )
265- fill_values = np .take (site_modes , np .where (missing_mask )[0 ])
254+ # Vectorized computation of mode per site (row)
255+ valid = ~ missing_mask
256+
257+ # Count occurrences of 0, 1, 2 per row
258+ # gn is produced by to_n_ref() so values are guaranteed to be
259+ # 0, 1, or 2 (ref allele count for biallelic sites), with -127
260+ # for missing calls.
261+ counts = np .stack (
262+ [
263+ np .sum ((arr == 0 ) & valid , axis = 1 ),
264+ np .sum ((arr == 1 ) & valid , axis = 1 ),
265+ np .sum ((arr == 2 ) & valid , axis = 1 ),
266+ ],
267+ axis = 1 ,
268+ )
269+
270+ # Determine mode per row
271+ site_modes = np .argmax (counts , axis = 1 )
272+
273+ # Handle rows where all values are missing
274+ all_missing = ~ valid .any (axis = 1 )
275+ site_modes [all_missing ] = 0
276+
277+ # Fill missing values
278+ fill_values = site_modes [np .where (missing_mask )[0 ]]
279+
266280 elif imputation_method == "mean" :
267281 site_means = np .where (
268282 np .all (missing_mask , axis = 1 , keepdims = True ),
0 commit comments