Skip to content

Commit 0f54741

Browse files
Merge branch 'master' into fix/issue-1280-vcf-performance
2 parents b87e50b + cc94c93 commit 0f54741

15 files changed

Lines changed: 1058 additions & 510 deletions

malariagen_data/anoph/base_params.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,12 @@ def _validate_sample_selection_params(
326326
to select SNPs to be included
327327
""",
328328
]
329+
330+
return_dataset: TypeAlias = Annotated[
331+
bool,
332+
"""
333+
If True, return an xarray Dataset containing computed results as
334+
additional data variables. If False (default), return the legacy
335+
format (numpy array or tuple) for backward compatibility.
336+
""",
337+
]

malariagen_data/anoph/distance.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Standard library imports.
2-
from typing import Optional, Tuple
2+
from typing import Any, Optional, Tuple
33
import math
44

55
# Third-party library imports.
@@ -86,7 +86,12 @@ def __init__(self, **kwargs):
8686
summary="""
8787
Compute pairwise distances between samples using biallelic SNP genotypes.
8888
""",
89-
returns=("dist", "samples", "n_snps_used"),
89+
returns="""
90+
If `return_dataset` is False (default), return a tuple
91+
`(dist, samples, n_snps_used)`. If `return_dataset` is True,
92+
return an xarray Dataset with `dist`, `sample_id`, and
93+
`n_snps_used` as variables/coordinates.
94+
""",
9095
)
9196
def biallelic_diplotype_pairwise_distances(
9297
self,
@@ -108,9 +113,8 @@ def biallelic_diplotype_pairwise_distances(
108113
random_seed: base_params.random_seed = 42,
109114
inline_array: base_params.inline_array = base_params.inline_array_default,
110115
chunks: base_params.chunks = base_params.native_chunks,
111-
) -> Tuple[
112-
distance_params.dist, distance_params.samples, distance_params.n_snps_used
113-
]:
116+
return_dataset: base_params.return_dataset = False,
117+
) -> Any:
114118
# Change this name if you ever change the behaviour of this function, to
115119
# invalidate any previously cached data.
116120
name = "biallelic_diplotype_pairwise_distances"
@@ -173,6 +177,22 @@ def biallelic_diplotype_pairwise_distances(
173177
samples: np.ndarray = results["samples"]
174178
n_snps_used: int = int(results["n_snps"][()]) # ensure scalar
175179

180+
if return_dataset:
181+
import xarray as xr
182+
from scipy.spatial.distance import squareform
183+
184+
dist_square = squareform(dist)
185+
ds = xr.Dataset(
186+
data_vars={
187+
"dist": (("sample_x", "sample_y"), dist_square),
188+
},
189+
coords={
190+
"sample_id": ("sample_x", samples),
191+
},
192+
attrs={"n_snps_used": n_snps_used},
193+
)
194+
return ds
195+
176196
return dist, samples, n_snps_used
177197

178198
def _biallelic_diplotype_pairwise_distances(
@@ -195,7 +215,7 @@ def _biallelic_diplotype_pairwise_distances(
195215
max_missing_an,
196216
):
197217
# Compute diplotypes.
198-
gn, samples = self.biallelic_diplotypes(
218+
ds = self.biallelic_diplotypes(
199219
region=region,
200220
sample_sets=sample_sets,
201221
sample_indices=sample_indices,
@@ -211,7 +231,10 @@ def _biallelic_diplotype_pairwise_distances(
211231
min_minor_ac=min_minor_ac,
212232
n_snps=n_snps,
213233
thin_offset=thin_offset,
234+
return_dataset=True,
214235
)
236+
gn = ds["call_diplotype"].values
237+
samples = ds["sample_id"].values.astype("U")
215238

216239
# Record number of SNPs used.
217240
n_snps = gn.shape[0]

malariagen_data/anoph/fst.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Tuple, Optional
23

34
import numpy as np
@@ -43,6 +44,8 @@ def _fst_gwss(
4344
inline_array,
4445
chunks,
4546
clip_min,
47+
min_snps_threshold,
48+
window_adjustment_factor,
4649
):
4750
# Compute allele counts.
4851
ac1 = self.snp_allele_counts(
@@ -81,6 +84,24 @@ def _fst_gwss(
8184
chunks=chunks,
8285
).compute()
8386

87+
n_snps = len(pos)
88+
if n_snps < min_snps_threshold:
89+
raise ValueError(
90+
f"Too few SNP sites ({n_snps}) available for Fst GWSS. "
91+
f"At least {min_snps_threshold} sites are required. "
92+
"Try a larger genomic region or different site selection criteria."
93+
)
94+
if window_size >= n_snps:
95+
adjusted_window_size = max(1, n_snps // window_adjustment_factor)
96+
warnings.warn(
97+
f"window_size ({window_size}) is >= the number of SNP sites "
98+
f"available ({n_snps}); automatically adjusting window_size to "
99+
f"{adjusted_window_size} (= {n_snps} // {window_adjustment_factor}).",
100+
UserWarning,
101+
stacklevel=2,
102+
)
103+
window_size = adjusted_window_size
104+
84105
with self._spinner(desc="Compute Fst"):
85106
with np.errstate(divide="ignore", invalid="ignore"):
86107
fst = allel.moving_hudson_fst(ac1, ac2, size=window_size)
@@ -96,8 +117,23 @@ def _fst_gwss(
96117
@doc(
97118
summary="""
98119
Run a Fst genome-wide scan to investigate genetic differentiation
99-
between two cohorts.
120+
between two cohorts. If window_size is >= the number of available
121+
SNP sites, a UserWarning is issued and window_size is automatically
122+
adjusted to number_of_snps // window_adjustment_factor. A ValueError
123+
is raised if the number of available SNP sites is below
124+
min_snps_threshold.
100125
""",
126+
parameters=dict(
127+
min_snps_threshold="""
128+
Minimum number of SNP sites required. If fewer sites are
129+
available a ValueError is raised.
130+
""",
131+
window_adjustment_factor="""
132+
If window_size is >= the number of available SNP sites,
133+
window_size is automatically set to
134+
number_of_snps // window_adjustment_factor.
135+
""",
136+
),
101137
returns=dict(
102138
x="An array containing the window centre point genomic positions",
103139
fst="An array with Fst statistic values for each window.",
@@ -123,6 +159,8 @@ def fst_gwss(
123159
inline_array: base_params.inline_array = base_params.inline_array_default,
124160
chunks: base_params.chunks = base_params.native_chunks,
125161
clip_min: fst_params.clip_min = 0.0,
162+
min_snps_threshold: fst_params.min_snps_threshold = 1000,
163+
window_adjustment_factor: fst_params.window_adjustment_factor = 10,
126164
) -> Tuple[np.ndarray, np.ndarray]:
127165
# Change this name if you ever change the behaviour of this function, to
128166
# invalidate any previously cached data.
@@ -147,7 +185,13 @@ def fst_gwss(
147185
results = self.results_cache_get(name=name, params=params)
148186

149187
except CacheMiss:
150-
results = self._fst_gwss(**params, inline_array=inline_array, chunks=chunks)
188+
results = self._fst_gwss(
189+
**params,
190+
inline_array=inline_array,
191+
chunks=chunks,
192+
min_snps_threshold=min_snps_threshold,
193+
window_adjustment_factor=window_adjustment_factor,
194+
)
151195
self.results_cache_set(name=name, params=params, results=results)
152196

153197
x = results["x"]

malariagen_data/anoph/fst_params.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@
3434
""",
3535
]
3636

37+
min_snps_threshold: TypeAlias = Annotated[
38+
int,
39+
"""
40+
Minimum number of SNP sites required for the Fst GWSS computation. If
41+
fewer sites are available, a ValueError is raised.
42+
""",
43+
]
44+
45+
window_adjustment_factor: TypeAlias = Annotated[
46+
int,
47+
"""
48+
If window_size is >= the number of available SNP sites, the window_size
49+
is automatically adjusted to number_of_snps // window_adjustment_factor.
50+
""",
51+
]
52+
3753
annotation: TypeAlias = Annotated[
3854
Optional[Literal["standard error", "Z score", "lower triangle"]],
3955
"""

malariagen_data/anoph/hapclust.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Optional, Tuple
2+
from typing import Any, Optional
33

44
import allel # type: ignore
55
import numpy as np
@@ -204,11 +204,12 @@ def plot_haplotype_clustering(
204204
summary="""
205205
Compute pairwise distances between haplotypes.
206206
""",
207-
returns=dict(
208-
dist="Pairwise distance.",
209-
phased_samples="Sample identifiers for haplotypes.",
210-
n_snps="Number of SNPs used.",
211-
),
207+
returns="""
208+
If `return_dataset` is False (default), return a tuple
209+
`(dist, phased_samples, n_snps)`. If `return_dataset` is True,
210+
return an xarray Dataset with `dist`, `sample_id`, and
211+
`n_snps` as variables/attributes.
212+
""",
212213
)
213214
def haplotype_pairwise_distances(
214215
self,
@@ -222,7 +223,8 @@ def haplotype_pairwise_distances(
222223
random_seed: base_params.random_seed = 42,
223224
chunks: base_params.chunks = base_params.native_chunks,
224225
inline_array: base_params.inline_array = base_params.inline_array_default,
225-
) -> Tuple[np.ndarray, np.ndarray, int]:
226+
return_dataset: base_params.return_dataset = False,
227+
) -> Any:
226228
# Change this name if you ever change the behaviour of this function, to
227229
# invalidate any previously cached data.
228230
name = "haplotype_pairwise_distances"
@@ -255,11 +257,30 @@ def haplotype_pairwise_distances(
255257
)
256258
self.results_cache_set(name=name, params=params, results=results)
257259

258-
# Unpack results")
260+
# Unpack results.
259261
dist: np.ndarray = results["dist"]
260262
phased_samples: np.ndarray = results["phased_samples"]
261263
n_snps: int = int(results["n_snps"][()]) # ensure scalar
262264

265+
if return_dataset:
266+
import xarray as xr
267+
from scipy.spatial.distance import squareform
268+
269+
dist_square = squareform(dist)
270+
# Each phased sample contributes 2 haplotypes; create
271+
# haplotype-level labels to match the distance matrix.
272+
hap_labels = np.repeat(phased_samples, 2)
273+
ds = xr.Dataset(
274+
data_vars={
275+
"dist": (("sample_x", "sample_y"), dist_square),
276+
},
277+
coords={
278+
"sample_id": ("sample_x", hap_labels),
279+
},
280+
attrs={"n_snps": n_snps},
281+
)
282+
return ds
283+
263284
return dist, phased_samples, n_snps
264285

265286
def _haplotype_pairwise_distances(

malariagen_data/anoph/pca.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _pca(
207207
inline_array,
208208
):
209209
# Load diplotypes.
210-
gn, samples = self.biallelic_diplotypes(
210+
ds_diplotypes = self.biallelic_diplotypes(
211211
region=region,
212212
n_snps=n_snps,
213213
thin_offset=thin_offset,
@@ -223,7 +223,10 @@ def _pca(
223223
random_seed=random_seed,
224224
chunks=chunks,
225225
inline_array=inline_array,
226+
return_dataset=True,
226227
)
228+
gn = ds_diplotypes["call_diplotype"].values
229+
samples = ds_diplotypes["sample_id"].values.astype("U")
227230

228231
with self._spinner(desc="Compute PCA"):
229232
# Exclude any samples prior to computing PCA.

0 commit comments

Comments
 (0)