Skip to content

Commit 51561b2

Browse files
committed
fix: use imputation_method parameter in _pca imputation logic
1 parent 1df1084 commit 51561b2

1 file changed

Lines changed: 28 additions & 12 deletions

File tree

malariagen_data/anoph/pca.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def pca(
6969
max_missing_an: Optional[
7070
base_params.max_missing_an
7171
] = pca_params.max_missing_an_default,
72+
imputation_method: str = "mean",
7273
cohort_size: Optional[base_params.cohort_size] = None,
7374
min_cohort_size: Optional[base_params.min_cohort_size] = None,
7475
max_cohort_size: Optional[base_params.max_cohort_size] = None,
@@ -80,7 +81,7 @@ def pca(
8081
) -> Tuple[pca_params.df_pca, pca_params.evr]:
8182
# Change this name if you ever change the behaviour of this function, to
8283
# invalidate any previously cached data.
83-
name = "pca_v6"
84+
name = "pca_v7"
8485

8586
# Check that either sample_query xor sample_indices are provided.
8687
base_params._validate_sample_selection_params(
@@ -121,6 +122,7 @@ def pca(
121122
site_class=site_class,
122123
min_minor_ac=min_minor_ac,
123124
max_missing_an=max_missing_an,
125+
imputation_method=imputation_method,
124126
n_components=n_components,
125127
cohort_size=cohort_size,
126128
min_cohort_size=min_cohort_size,
@@ -185,6 +187,7 @@ def _pca(
185187
site_class,
186188
min_minor_ac,
187189
max_missing_an,
190+
imputation_method="mean",
188191
n_components,
189192
cohort_size,
190193
min_cohort_size,
@@ -231,22 +234,35 @@ def _pca(
231234
loc_keep_fit = np.ones(len(samples), dtype=bool)
232235
gn_fit = gn
233236

234-
# Impute missing calls (-127) with the mean value at each site.
237+
# Impute missing calls (-127) using the chosen imputation method.
235238
if max_missing_an is not None and max_missing_an > 0:
236239
gn_fit = gn_fit.astype(float)
237240
gn = gn.astype(float)
238241
for arr in [gn_fit, gn]:
239242
missing_mask = arr == -127
240-
site_means = np.where(
241-
np.all(missing_mask, axis=1, keepdims=True),
242-
0,
243-
np.nanmean(
244-
np.where(missing_mask, np.nan, arr), axis=1, keepdims=True
245-
),
246-
)
247-
arr[missing_mask] = np.take(
248-
site_means.flatten(), np.where(missing_mask)[0]
249-
)
243+
244+
if imputation_method == "mean":
245+
site_means = np.where(
246+
np.all(missing_mask, axis=1, keepdims=True),
247+
0,
248+
np.nanmean(
249+
np.where(missing_mask, np.nan, arr),
250+
axis=1,
251+
keepdims=True,
252+
),
253+
)
254+
fill_values = np.take(
255+
site_means.flatten(), np.where(missing_mask)[0]
256+
)
257+
elif imputation_method == "zero":
258+
fill_values = 0
259+
else:
260+
raise ValueError(
261+
f"Unknown imputation_method: {imputation_method!r}. "
262+
"Choose from 'mean' or 'zero'."
263+
)
264+
265+
arr[missing_mask] = fill_values
250266

251267
# Remove any sites where all genotypes are identical.
252268
loc_var = np.any(gn_fit != gn_fit[:, 0, np.newaxis], axis=1)

0 commit comments

Comments
 (0)