@@ -44,6 +44,14 @@ def __init__(
4444 `random_seed`.
4545
4646 """ ,
47+ parameters = dict (
48+ imputation_method = """
49+ Method to use for imputing missing genotype calls. Options are
50+ 'most_common' (replace missing calls with the most common genotype at each site,
51+ the default), 'mean' (replace missing calls with the
52+ mean value at each site), or 'zero' (replace missing calls with zero).
53+ """ ,
54+ ),
4755 returns = ("df_pca" , "evr" ),
4856 notes = """
4957 This computation may take some time to run, depending on your computing
@@ -69,6 +77,7 @@ def pca(
6977 max_missing_an : Optional [
7078 base_params .max_missing_an
7179 ] = pca_params .max_missing_an_default ,
80+ imputation_method : pca_params .imputation_method = pca_params .imputation_method_default ,
7281 cohort_size : Optional [base_params .cohort_size ] = None ,
7382 min_cohort_size : Optional [base_params .min_cohort_size ] = None ,
7483 max_cohort_size : Optional [base_params .max_cohort_size ] = None ,
@@ -80,7 +89,7 @@ def pca(
8089 ) -> Tuple [pca_params .df_pca , pca_params .evr ]:
8190 # Change this name if you ever change the behaviour of this function, to
8291 # invalidate any previously cached data.
83- name = "pca_v5 "
92+ name = "pca_v8 "
8493
8594 # Check that either sample_query xor sample_indices are provided.
8695 base_params ._validate_sample_selection_params (
@@ -121,6 +130,7 @@ def pca(
121130 site_class = site_class ,
122131 min_minor_ac = min_minor_ac ,
123132 max_missing_an = max_missing_an ,
133+ imputation_method = imputation_method ,
124134 n_components = n_components ,
125135 cohort_size = cohort_size ,
126136 min_cohort_size = min_cohort_size ,
@@ -152,7 +162,7 @@ def pca(
152162 # df_pca.index = df_pca.index.astype(str)
153163
154164 # Name the DataFrame's columns as PC1, PC2, etc.
155- df_pca .columns = pd .Index ([f"PC{ i + 1 } " for i in range (coords .shape [1 ])])
165+ df_pca .columns = pd .Index ([f"PC{ i + 1 } " for i in range (coords .shape [1 ])])
156166
157167 # Load the sample metadata.
158168 df_samples = self .sample_metadata (
@@ -185,6 +195,7 @@ def _pca(
185195 site_class ,
186196 min_minor_ac ,
187197 max_missing_an ,
198+ imputation_method = "most_common" ,
188199 n_components ,
189200 cohort_size ,
190201 min_cohort_size ,
@@ -231,6 +242,50 @@ def _pca(
231242 loc_keep_fit = np .ones (len (samples ), dtype = bool )
232243 gn_fit = gn
233244
245+ # Impute missing calls (-127) using the chosen imputation method.
246+ if max_missing_an is not None and max_missing_an > 0 :
247+ gn_fit = gn_fit .astype (float )
248+ gn = gn .astype (float )
249+ for arr in [gn_fit , gn ]:
250+ missing_mask = arr == - 127
251+
252+ if imputation_method == "most_common" :
253+ # For each site, find the most common non-missing value.
254+ site_modes = []
255+ for row in arr :
256+ non_missing = row [row != - 127 ]
257+ if len (non_missing ) == 0 :
258+ site_modes .append (0 )
259+ else :
260+ values , counts = np .unique (
261+ non_missing , return_counts = True
262+ )
263+ site_modes .append (values [np .argmax (counts )])
264+ site_modes = np .array (site_modes )
265+ fill_values = np .take (site_modes , np .where (missing_mask )[0 ])
266+ elif imputation_method == "mean" :
267+ site_means = np .where (
268+ np .all (missing_mask , axis = 1 , keepdims = True ),
269+ 0 ,
270+ np .nanmean (
271+ np .where (missing_mask , np .nan , arr ),
272+ axis = 1 ,
273+ keepdims = True ,
274+ ),
275+ )
276+ fill_values = np .take (
277+ site_means .flatten (), np .where (missing_mask )[0 ]
278+ )
279+ elif imputation_method == "zero" :
280+ fill_values = 0
281+ else :
282+ raise ValueError (
283+ f"Unknown imputation_method: { imputation_method !r} . "
284+ "Choose from 'most_common', 'mean' or 'zero'."
285+ )
286+
287+ arr [missing_mask ] = fill_values
288+
234289 # Remove any sites where all genotypes are identical.
235290 loc_var = np .any (gn_fit != gn_fit [:, 0 , np .newaxis ], axis = 1 )
236291 gn_fit_var = np .compress (loc_var , gn_fit , axis = 0 )
0 commit comments