@@ -47,8 +47,9 @@ def __init__(
4747 parameters = dict (
4848 imputation_method = """
4949 Method to use for imputing missing genotype calls. Options are
50- 'mean' (replace missing calls with the mean value at each site,
51- the default) or 'zero' (replace missing calls with zero).
50+ 'most_common' (replace missing calls with the most common genotype at each site,
51+ the default), 'mean' (replace missing calls with the
52+ mean value at each site), or 'zero' (replace missing calls with zero).
5253 """ ,
5354 ),
5455 returns = ("df_pca" , "evr" ),
@@ -76,7 +77,7 @@ def pca(
7677 max_missing_an : Optional [
7778 base_params .max_missing_an
7879 ] = pca_params .max_missing_an_default ,
79- imputation_method : str = "mean" ,
80+ imputation_method : pca_params . imputation_method = pca_params . imputation_method_default ,
8081 cohort_size : Optional [base_params .cohort_size ] = None ,
8182 min_cohort_size : Optional [base_params .min_cohort_size ] = None ,
8283 max_cohort_size : Optional [base_params .max_cohort_size ] = None ,
@@ -88,7 +89,7 @@ def pca(
8889 ) -> Tuple [pca_params .df_pca , pca_params .evr ]:
8990 # Change this name if you ever change the behaviour of this function, to
9091 # invalidate any previously cached data.
91- name = "pca_v7 "
92+ name = "pca_v8 "
9293
9394 # Check that either sample_query xor sample_indices are provided.
9495 base_params ._validate_sample_selection_params (
@@ -194,7 +195,7 @@ def _pca(
194195 site_class ,
195196 min_minor_ac ,
196197 max_missing_an ,
197- imputation_method = "mean " ,
198+ imputation_method = "most_common " ,
198199 n_components ,
199200 cohort_size ,
200201 min_cohort_size ,
@@ -248,7 +249,19 @@ def _pca(
248249 for arr in [gn_fit , gn ]:
249250 missing_mask = arr == - 127
250251
251- if imputation_method == "mean" :
252+ if imputation_method == "most_common" :
253+ # For each site (row), find the most common non-missing value; ties go to the smallest value and all-missing sites fall back to 0.
254+ site_modes = []
255+ for row in arr :
256+ non_missing = row [row != - 127 ]
257+ if len (non_missing ) == 0 :
258+ site_modes .append (0 )
259+ else :
260+ values , counts = np .unique (non_missing , return_counts = True )
261+ site_modes .append (values [np .argmax (counts )])
262+ site_modes = np .array (site_modes )
263+ fill_values = np .take (site_modes , np .where (missing_mask )[0 ])
264+ elif imputation_method == "mean" :
252265 site_means = np .where (
253266 np .all (missing_mask , axis = 1 , keepdims = True ),
254267 0 ,
@@ -266,7 +279,7 @@ def _pca(
266279 else :
267280 raise ValueError (
268281 f"Unknown imputation_method: { imputation_method !r} . "
269- "Choose from 'mean' or 'zero'."
282+ "Choose from 'most_common', ' mean' or 'zero'."
270283 )
271284
272285 arr [missing_mask ] = fill_values
0 commit comments