@@ -69,6 +69,7 @@ def pca(
6969 max_missing_an : Optional [
7070 base_params .max_missing_an
7171 ] = pca_params .max_missing_an_default ,
72+ imputation_method : str = "mean" ,
7273 cohort_size : Optional [base_params .cohort_size ] = None ,
7374 min_cohort_size : Optional [base_params .min_cohort_size ] = None ,
7475 max_cohort_size : Optional [base_params .max_cohort_size ] = None ,
@@ -80,7 +81,7 @@ def pca(
8081 ) -> Tuple [pca_params .df_pca , pca_params .evr ]:
8182 # Change this name if you ever change the behaviour of this function, to
8283 # invalidate any previously cached data.
83- name = "pca_v6 "
84+ name = "pca_v7 "
8485
8586 # Check that either sample_query xor sample_indices are provided.
8687 base_params ._validate_sample_selection_params (
@@ -121,6 +122,7 @@ def pca(
121122 site_class = site_class ,
122123 min_minor_ac = min_minor_ac ,
123124 max_missing_an = max_missing_an ,
125+ imputation_method = imputation_method ,
124126 n_components = n_components ,
125127 cohort_size = cohort_size ,
126128 min_cohort_size = min_cohort_size ,
@@ -185,6 +187,7 @@ def _pca(
185187 site_class ,
186188 min_minor_ac ,
187189 max_missing_an ,
190+ imputation_method = "mean" ,
188191 n_components ,
189192 cohort_size ,
190193 min_cohort_size ,
@@ -231,22 +234,35 @@ def _pca(
231234 loc_keep_fit = np .ones (len (samples ), dtype = bool )
232235 gn_fit = gn
233236
234- # Impute missing calls (-127) with the mean value at each site .
237+ # Impute missing calls (-127) using the chosen imputation method ("mean" or "zero").
235238 if max_missing_an is not None and max_missing_an > 0 :
236239 gn_fit = gn_fit .astype (float )
237240 gn = gn .astype (float )
238241 for arr in [gn_fit , gn ]:
239242 missing_mask = arr == - 127
240- site_means = np .where (
241- np .all (missing_mask , axis = 1 , keepdims = True ),
242- 0 ,
243- np .nanmean (
244- np .where (missing_mask , np .nan , arr ), axis = 1 , keepdims = True
245- ),
246- )
247- arr [missing_mask ] = np .take (
248- site_means .flatten (), np .where (missing_mask )[0 ]
249- )
243+
244+ if imputation_method == "mean" :
245+ site_means = np .where (
246+ np .all (missing_mask , axis = 1 , keepdims = True ),
247+ 0 ,
248+ np .nanmean (
249+ np .where (missing_mask , np .nan , arr ),
250+ axis = 1 ,
251+ keepdims = True ,
252+ ),
253+ )
254+ fill_values = np .take (
255+ site_means .flatten (), np .where (missing_mask )[0 ]
256+ )
257+ elif imputation_method == "zero" :
258+ fill_values = 0
259+ else :
260+ raise ValueError (
261+ f"Unknown imputation_method: { imputation_method !r} . "
262+ "Choose from 'mean' or 'zero'."
263+ )
264+
265+ arr [missing_mask ] = fill_values
250266
251267 # Remove any sites where all genotypes are identical.
252268 loc_var = np .any (gn_fit != gn_fit [:, 0 , np .newaxis ], axis = 1 )
0 commit comments