Skip to content

Commit 6b8c6ff

Browse files
committed
fix: address review feedback - use cache for dataset returns, add missing functions
1 parent f3afe7e commit 6b8c6ff

5 files changed

Lines changed: 125 additions & 141 deletions

File tree

malariagen_data/anoph/base_params.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,12 @@ def _validate_sample_selection_params(
326326
to select SNPs to be included
327327
""",
328328
]
329+
330+
return_dataset: TypeAlias = Annotated[
331+
bool,
332+
"""
333+
If True, return an xarray Dataset containing computed results as
334+
additional data variables. If False (default), return the legacy
335+
format (numpy array or tuple) for backward compatibility.
336+
""",
337+
]

malariagen_data/anoph/distance.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Standard library imports.
2-
from typing import Optional, Tuple
2+
from typing import Any, Optional, Tuple
33
import math
44

55
# Third-party library imports.
@@ -86,7 +86,12 @@ def __init__(self, **kwargs):
8686
summary="""
8787
Compute pairwise distances between samples using biallelic SNP genotypes.
8888
""",
89-
returns=("dist", "samples", "n_snps_used"),
89+
returns="""
90+
If `return_dataset` is False (default), return a tuple
91+
`(dist, samples, n_snps_used)`. If `return_dataset` is True,
92+
return an xarray Dataset with `dist` as a data variable, `sample_id`
93+
as a coordinate, and `n_snps_used` as a dataset attribute.
94+
""",
9095
)
9196
def biallelic_diplotype_pairwise_distances(
9297
self,
@@ -108,22 +113,14 @@ def biallelic_diplotype_pairwise_distances(
108113
random_seed: base_params.random_seed = 42,
109114
inline_array: base_params.inline_array = base_params.inline_array_default,
110115
chunks: base_params.chunks = base_params.native_chunks,
111-
) -> Tuple[
112-
distance_params.dist, distance_params.samples, distance_params.n_snps_used
113-
]:
114-
# Change this name if you ever change the behaviour of this function, to
115-
# invalidate any previously cached data.
116+
return_dataset: base_params.return_dataset = False,
117+
) -> Any:
116118
name = "biallelic_diplotype_pairwise_distances"
117119

118-
# Check that either sample_query xor sample_indices are provided.
119120
base_params._validate_sample_selection_params(
120121
sample_query=sample_query, sample_indices=sample_indices
121122
)
122123

123-
## Normalize params for consistent hash value.
124-
125-
# Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
126-
# So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
127124
(
128125
sample_sets_prepped,
129126
sample_indices_prepped,
@@ -158,7 +155,6 @@ def biallelic_diplotype_pairwise_distances(
158155
max_missing_an=max_missing_an,
159156
)
160157

161-
# Try to retrieve results from the cache.
162158
try:
163159
results = self.results_cache_get(name=name, params=params)
164160

@@ -168,10 +164,24 @@ def biallelic_diplotype_pairwise_distances(
168164
)
169165
self.results_cache_set(name=name, params=params, results=results)
170166

171-
# Unpack results.
172167
dist: np.ndarray = results["dist"]
173168
samples: np.ndarray = results["samples"]
174-
n_snps_used: int = int(results["n_snps"][()]) # ensure scalar
169+
n_snps_used: int = int(results["n_snps"][()])
170+
171+
if return_dataset:
172+
import xarray as xr
173+
from scipy.spatial.distance import squareform
174+
dist_square = squareform(dist)
175+
ds = xr.Dataset(
176+
data_vars={
177+
"dist": (("sample_x", "sample_y"), dist_square),
178+
},
179+
coords={
180+
"sample_id": ("sample_x", samples),
181+
},
182+
attrs={"n_snps_used": n_snps_used},
183+
)
184+
return ds
175185

176186
return dist, samples, n_snps_used
177187

malariagen_data/anoph/hapclust.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Optional, Tuple
2+
from typing import Any, Optional, Tuple
33

44
import allel # type: ignore
55
import numpy as np
@@ -204,11 +204,12 @@ def plot_haplotype_clustering(
204204
summary="""
205205
Compute pairwise distances between haplotypes.
206206
""",
207-
returns=dict(
208-
dist="Pairwise distance.",
209-
phased_samples="Sample identifiers for haplotypes.",
210-
n_snps="Number of SNPs used.",
211-
),
207+
returns="""
208+
If `return_dataset` is False (default), return a tuple
209+
`(dist, phased_samples, n_snps)`. If `return_dataset` is True,
210+
return an xarray Dataset with `dist` as a data variable, `sample_id`
211+
as a coordinate, and `n_snps` as a dataset attribute.
212+
""",
212213
)
213214
def haplotype_pairwise_distances(
214215
self,
@@ -222,12 +223,10 @@ def haplotype_pairwise_distances(
222223
random_seed: base_params.random_seed = 42,
223224
chunks: base_params.chunks = base_params.native_chunks,
224225
inline_array: base_params.inline_array = base_params.inline_array_default,
225-
) -> Tuple[np.ndarray, np.ndarray, int]:
226-
# Change this name if you ever change the behaviour of this function, to
227-
# invalidate any previously cached data.
226+
return_dataset: base_params.return_dataset = False,
227+
) -> Any:
228228
name = "haplotype_pairwise_distances"
229229

230-
# Normalize params for consistent hash value.
231230
sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets)
232231
del sample_sets
233232
sample_query_prepped = self._prep_sample_query_param(sample_query=sample_query)
@@ -245,7 +244,6 @@ def haplotype_pairwise_distances(
245244
random_seed=random_seed,
246245
)
247246

248-
# Try to retrieve results from the cache.
249247
try:
250248
results = self.results_cache_get(name=name, params=params)
251249

@@ -255,10 +253,24 @@ def haplotype_pairwise_distances(
255253
)
256254
self.results_cache_set(name=name, params=params, results=results)
257255

258-
# Unpack results")
259256
dist: np.ndarray = results["dist"]
260257
phased_samples: np.ndarray = results["phased_samples"]
261-
n_snps: int = int(results["n_snps"][()]) # ensure scalar
258+
n_snps: int = int(results["n_snps"][()])
259+
260+
if return_dataset:
261+
import xarray as xr
262+
from scipy.spatial.distance import squareform
263+
dist_square = squareform(dist)
264+
ds = xr.Dataset(
265+
data_vars={
266+
"dist": (("sample_x", "sample_y"), dist_square),
267+
},
268+
coords={
269+
"sample_id": ("sample_x", phased_samples),
270+
},
271+
attrs={"n_snps": n_snps},
272+
)
273+
return ds
262274

263275
return dist, phased_samples, n_snps
264276

0 commit comments

Comments
 (0)