Skip to content

Commit 3f848ef

Browse files
committed
fix: add imputation_method to pca_params.py and bump cache to pca_v8
1 parent 726d781 commit 3f848ef

2 files changed

Lines changed: 29 additions & 7 deletions

File tree

malariagen_data/anoph/pca.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ def __init__(
4747
parameters=dict(
4848
imputation_method="""
4949
Method to use for imputing missing genotype calls. Options are
50-
'mean' (replace missing calls with the mean value at each site,
51-
the default) or 'zero' (replace missing calls with zero).
50+
'most_common' (replace missing calls with the most common genotype at each site,
51+
the default), 'mean' (replace missing calls with the
52+
mean value at each site), or 'zero' (replace missing calls with zero).
5253
""",
5354
),
5455
returns=("df_pca", "evr"),
@@ -76,7 +77,7 @@ def pca(
7677
max_missing_an: Optional[
7778
base_params.max_missing_an
7879
] = pca_params.max_missing_an_default,
79-
imputation_method: str = "mean",
80+
imputation_method: pca_params.imputation_method = pca_params.imputation_method_default,
8081
cohort_size: Optional[base_params.cohort_size] = None,
8182
min_cohort_size: Optional[base_params.min_cohort_size] = None,
8283
max_cohort_size: Optional[base_params.max_cohort_size] = None,
@@ -88,7 +89,7 @@ def pca(
8889
) -> Tuple[pca_params.df_pca, pca_params.evr]:
8990
# Change this name if you ever change the behaviour of this function, to
9091
# invalidate any previously cached data.
91-
name = "pca_v7"
92+
name = "pca_v8"
9293

9394
# Check that either sample_query xor sample_indices are provided.
9495
base_params._validate_sample_selection_params(
@@ -194,7 +195,7 @@ def _pca(
194195
site_class,
195196
min_minor_ac,
196197
max_missing_an,
197-
imputation_method="mean",
198+
imputation_method="most_common",
198199
n_components,
199200
cohort_size,
200201
min_cohort_size,
@@ -248,7 +249,19 @@ def _pca(
248249
for arr in [gn_fit, gn]:
249250
missing_mask = arr == -127
250251

251-
if imputation_method == "mean":
252+
if imputation_method == "most_common":
253+
# For each site, find the most common non-missing value.
254+
site_modes = []
255+
for row in arr:
256+
non_missing = row[row != -127]
257+
if len(non_missing) == 0:
258+
site_modes.append(0)
259+
else:
260+
values, counts = np.unique(non_missing, return_counts=True)
261+
site_modes.append(values[np.argmax(counts)])
262+
site_modes = np.array(site_modes)
263+
fill_values = np.take(site_modes, np.where(missing_mask)[0])
264+
elif imputation_method == "mean":
252265
site_means = np.where(
253266
np.all(missing_mask, axis=1, keepdims=True),
254267
0,
@@ -266,7 +279,7 @@ def _pca(
266279
else:
267280
raise ValueError(
268281
f"Unknown imputation_method: {imputation_method!r}. "
269-
"Choose from 'mean' or 'zero'."
282+
"Choose from 'most_common', 'mean' or 'zero'."
270283
)
271284

272285
arr[missing_mask] = fill_values

malariagen_data/anoph/pca_params.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,12 @@
8686
min_minor_ac_default: base_params.min_minor_ac = 2
8787

8888
max_missing_an_default: base_params.max_missing_an = 0
89+
90+
imputation_method: TypeAlias = Annotated[
91+
str,
92+
"Method to use for imputing missing genotype calls when max_missing_an > 0. "
93+
"Options are 'most_common' (replace missing calls with the most common genotype "
94+
"at each site), 'mean' (replace with the site mean), or 'zero' (replace with zero).",
95+
]
96+
97+
imputation_method_default: imputation_method = "most_common"

0 commit comments

Comments
 (0)