11# Standard library imports.
2- from typing import Optional , Tuple
2+ from typing import Any , Optional , Tuple
33import math
44
55# Third-party library imports.
@@ -86,7 +86,12 @@ def __init__(self, **kwargs):
8686 summary = """
8787 Compute pairwise distances between samples using biallelic SNP genotypes.
8888 """ ,
89- returns = ("dist" , "samples" , "n_snps_used" ),
89+ returns = """
90+ If `return_dataset` is False (default), return a tuple
91+ `(dist, samples, n_snps_used)`. If `return_dataset` is True,
92+ return an xarray Dataset with `dist` as a data variable, `sample_id`
93+ as a coordinate, and `n_snps_used` as a dataset attribute.
94+ """ ,
9095 )
9196 def biallelic_diplotype_pairwise_distances (
9297 self ,
@@ -108,9 +113,8 @@ def biallelic_diplotype_pairwise_distances(
108113 random_seed : base_params .random_seed = 42 ,
109114 inline_array : base_params .inline_array = base_params .inline_array_default ,
110115 chunks : base_params .chunks = base_params .native_chunks ,
111- ) -> Tuple [
112- distance_params .dist , distance_params .samples , distance_params .n_snps_used
113- ]:
116+ return_dataset : base_params .return_dataset = False ,
117+ ) -> Any :
114118 # Change this name if you ever change the behaviour of this function, to
115119 # invalidate any previously cached data.
116120 name = "biallelic_diplotype_pairwise_distances"
@@ -173,6 +177,22 @@ def biallelic_diplotype_pairwise_distances(
173177 samples : np .ndarray = results ["samples" ]
174178 n_snps_used : int = int (results ["n_snps" ][()]) # ensure scalar
175179
180+ if return_dataset :
181+ import xarray as xr
182+ from scipy .spatial .distance import squareform
183+
184+ dist_square = squareform (dist )
185+ ds = xr .Dataset (
186+ data_vars = {
187+ "dist" : (("sample_x" , "sample_y" ), dist_square ),
188+ },
189+ coords = {
190+ "sample_id" : ("sample_x" , samples ),
191+ },
192+ attrs = {"n_snps_used" : n_snps_used },
193+ )
194+ return ds
195+
176196 return dist , samples , n_snps_used
177197
178198 def _biallelic_diplotype_pairwise_distances (
@@ -195,7 +215,7 @@ def _biallelic_diplotype_pairwise_distances(
195215 max_missing_an ,
196216 ):
197217 # Compute diplotypes.
198- gn , samples = self .biallelic_diplotypes (
218+ ds = self .biallelic_diplotypes (
199219 region = region ,
200220 sample_sets = sample_sets ,
201221 sample_indices = sample_indices ,
@@ -211,7 +231,10 @@ def _biallelic_diplotype_pairwise_distances(
211231 min_minor_ac = min_minor_ac ,
212232 n_snps = n_snps ,
213233 thin_offset = thin_offset ,
234+ return_dataset = True ,
214235 )
236+ gn = ds ["call_diplotype" ].values
237+ samples = ds ["sample_id" ].values .astype ("U" )
215238
216239 # Record number of SNPs used.
217240 n_snps = gn .shape [0 ]
0 commit comments