Skip to content

Commit 27ac08c

Browse files
authored
Merge pull request #975 from Tanisha127/standardise-biallelic-diplotypes-471
refactor: standardise biallelic diplotypes and handling of missing calls
2 parents 51bd38f + 8e84a28 commit 27ac08c

5 files changed

Lines changed: 78 additions & 7 deletions

File tree

malariagen_data/anoph/distance.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ def _biallelic_diplotype_pairwise_distances(
217217
n_snps = gn.shape[0]
218218

219219
# Prepare data for pairwise distance calculation.
220+
# Mask missing calls (-127) before computing distances.
221+
gn = gn.astype(float)
222+
gn[gn == -127] = np.nan
220223
X = np.ascontiguousarray(gn.T)
221224

222225
# Look up distance function.

malariagen_data/anoph/pca.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ def __init__(
4444
`random_seed`.
4545
4646
""",
47+
parameters=dict(
48+
imputation_method="""
49+
Method to use for imputing missing genotype calls. Options are
50+
'most_common' (replace missing calls with the most common genotype at each site,
51+
the default), 'mean' (replace missing calls with the
52+
mean value at each site), or 'zero' (replace missing calls with zero).
53+
""",
54+
),
4755
returns=("df_pca", "evr"),
4856
notes="""
4957
This computation may take some time to run, depending on your computing
@@ -69,6 +77,7 @@ def pca(
6977
max_missing_an: Optional[
7078
base_params.max_missing_an
7179
] = pca_params.max_missing_an_default,
80+
imputation_method: pca_params.imputation_method = pca_params.imputation_method_default,
7281
cohort_size: Optional[base_params.cohort_size] = None,
7382
min_cohort_size: Optional[base_params.min_cohort_size] = None,
7483
max_cohort_size: Optional[base_params.max_cohort_size] = None,
@@ -80,7 +89,7 @@ def pca(
8089
) -> Tuple[pca_params.df_pca, pca_params.evr]:
8190
# Change this name if you ever change the behaviour of this function, to
8291
# invalidate any previously cached data.
83-
name = "pca_v5"
92+
name = "pca_v8"
8493

8594
# Check that either sample_query xor sample_indices are provided.
8695
base_params._validate_sample_selection_params(
@@ -121,6 +130,7 @@ def pca(
121130
site_class=site_class,
122131
min_minor_ac=min_minor_ac,
123132
max_missing_an=max_missing_an,
133+
imputation_method=imputation_method,
124134
n_components=n_components,
125135
cohort_size=cohort_size,
126136
min_cohort_size=min_cohort_size,
@@ -152,7 +162,7 @@ def pca(
152162
# df_pca.index = df_pca.index.astype(str)
153163

154164
# Name the DataFrame's columns as PC1, PC2, etc.
155-
df_pca.columns = pd.Index([f"PC{i+1}" for i in range(coords.shape[1])])
165+
df_pca.columns = pd.Index([f"PC{i + 1}" for i in range(coords.shape[1])])
156166

157167
# Load the sample metadata.
158168
df_samples = self.sample_metadata(
@@ -185,6 +195,7 @@ def _pca(
185195
site_class,
186196
min_minor_ac,
187197
max_missing_an,
198+
# NOTE(review): this inserted parameter has a default ("most_common") but is
# followed by parameters without defaults (n_components, cohort_size, ...).
# That is a SyntaxError in Python unless those trailing parameters are
# keyword-only (a `*` earlier in the signature, not visible in this hunk) —
# please verify against the full `_pca` definition.
imputation_method="most_common",
188199
n_components,
189200
cohort_size,
190201
min_cohort_size,
@@ -231,6 +242,50 @@ def _pca(
231242
loc_keep_fit = np.ones(len(samples), dtype=bool)
232243
gn_fit = gn
233244

245+
# Impute missing calls (-127) using the chosen imputation method.
# NOTE(review): imputation only runs when max_missing_an > 0 (see the guard
# below); otherwise any -127 codes remain in gn — confirm that is intended.
246+
if max_missing_an is not None and max_missing_an > 0:
247+
gn_fit = gn_fit.astype(float)
248+
gn = gn.astype(float)
249+
for arr in [gn_fit, gn]:
250+
missing_mask = arr == -127
251+
252+
if imputation_method == "most_common":
253+
# For each site, find the most common non-missing value.
254+
site_modes = []
255+
for row in arr:
256+
non_missing = row[row != -127]
257+
if len(non_missing) == 0:
258+
site_modes.append(0)
259+
else:
260+
values, counts = np.unique(
261+
non_missing, return_counts=True
262+
)
263+
site_modes.append(values[np.argmax(counts)])
264+
site_modes = np.array(site_modes)
265+
fill_values = np.take(site_modes, np.where(missing_mask)[0])
266+
elif imputation_method == "mean":
267+
site_means = np.where(
268+
np.all(missing_mask, axis=1, keepdims=True),
269+
0,
270+
np.nanmean(
271+
np.where(missing_mask, np.nan, arr),
272+
axis=1,
273+
keepdims=True,
274+
),
275+
)
276+
fill_values = np.take(
277+
site_means.flatten(), np.where(missing_mask)[0]
278+
)
279+
elif imputation_method == "zero":
280+
fill_values = 0
281+
else:
282+
raise ValueError(
283+
f"Unknown imputation_method: {imputation_method!r}. "
284+
"Choose from 'most_common', 'mean' or 'zero'."
285+
)
286+
287+
arr[missing_mask] = fill_values
288+
234289
# Remove any sites where all genotypes are identical.
235290
loc_var = np.any(gn_fit != gn_fit[:, 0, np.newaxis], axis=1)
236291
gn_fit_var = np.compress(loc_var, gn_fit, axis=0)

malariagen_data/anoph/pca_params.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,12 @@
8686
min_minor_ac_default: base_params.min_minor_ac = 2
8787

8888
max_missing_an_default: base_params.max_missing_an = 0
89+
90+
imputation_method: TypeAlias = Annotated[
91+
str,
92+
"Method to use for imputing missing genotype calls when max_missing_an > 0. "
93+
"Options are 'most_common' (replace missing calls with the most common genotype "
94+
"at each site), 'mean' (replace with the site mean), or 'zero' (replace with zero).",
95+
]
96+
97+
imputation_method_default: imputation_method = "most_common"

malariagen_data/anoph/snp_data.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,7 +1939,7 @@ def biallelic_diplotypes(
19391939
) -> Tuple[np.ndarray, np.ndarray]:
19401940
# Change this name if you ever change the behaviour of this function, to
19411941
# invalidate any previously cached data.
1942-
name = "biallelic_diplotypes"
1942+
name = "biallelic_diplotypes_v2"
19431943

19441944
# Check that either sample_query xor sample_indices are provided.
19451945
base_params._validate_sample_selection_params(
@@ -2046,8 +2046,12 @@ def _biallelic_diplotypes(
20462046
samples = ds["sample_id"].values.astype("U")
20472047

20482048
# Compute diplotypes as the number of ref alleles per genotype call,
2049+
# with missing calls coded as -127.
20492050
gt = allel.GenotypeDaskArray(ds["call_genotype"].data)
20502051
with self._dask_progress(desc="Compute biallelic diplotypes"):
2051-
gn = gt.to_n_alt().compute()
2052+
gn = gt.to_n_ref().compute()
2053+
# Code missing calls as -127.
2054+
missing = np.all(ds["call_genotype"].values == -1, axis=2)
2055+
gn[missing] = -127
20522056

20532057
return dict(samples=samples, gn=gn)

tests/anoph/test_snp_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,10 +1388,10 @@ def check_biallelic_snp_calls_and_diplotypes(
13881388
assert gn.ndim == 2
13891389
assert gn.shape[0] == ds.sizes["variants"]
13901390
assert gn.shape[1] == ds.sizes["samples"]
1391-
assert np.all(gn >= 0)
1392-
assert np.all(gn <= 2)
1391+
assert np.all((gn >= 0) | (gn == -127))
1392+
assert np.all((gn <= 2) | (gn == -127))
13931393
ac = ds["variant_allele_count"].values
1394-
assert np.all(np.sum(gn, axis=1) == ac[:, 1])
1394+
assert np.all(np.sum(np.where(gn == -127, 0, gn), axis=1) == ac[:, 0])
13951395
assert samples.ndim == 1
13961396
assert samples.shape[0] == gn.shape[1]
13971397
assert samples.tolist() == expected_samples

0 commit comments

Comments
 (0)