11# Standard library imports.
2- from typing import Optional , Tuple
2+ from typing import Any , Optional , Tuple
33import math
44
55# Third-party library imports.
@@ -86,7 +86,12 @@ def __init__(self, **kwargs):
8686 summary = """
8787 Compute pairwise distances between samples using biallelic SNP genotypes.
8888 """ ,
89- returns = ("dist" , "samples" , "n_snps_used" ),
89+ returns = """
90+ If `return_dataset` is False (default), return a tuple
91+ `(dist, samples, n_snps_used)`. If `return_dataset` is True,
92+ return an xarray Dataset with `dist` as a data variable,
93+ `sample_id` as a coordinate, and `n_snps_used` in the dataset attrs.
94+ """ ,
9095 )
9196 def biallelic_diplotype_pairwise_distances (
9297 self ,
@@ -108,22 +113,14 @@ def biallelic_diplotype_pairwise_distances(
108113 random_seed : base_params .random_seed = 42 ,
109114 inline_array : base_params .inline_array = base_params .inline_array_default ,
110115 chunks : base_params .chunks = base_params .native_chunks ,
111- ) -> Tuple [
112- distance_params .dist , distance_params .samples , distance_params .n_snps_used
113- ]:
114- # Change this name if you ever change the behaviour of this function, to
115- # invalidate any previously cached data.
116+ return_dataset : base_params .return_dataset = False ,
117+ ) -> Any :
116118 name = "biallelic_diplotype_pairwise_distances"
117119
118- # Check that either sample_query xor sample_indices are provided.
119120 base_params ._validate_sample_selection_params (
120121 sample_query = sample_query , sample_indices = sample_indices
121122 )
122123
123- ## Normalize params for consistent hash value.
124-
125- # Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
126- # So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
127124 (
128125 sample_sets_prepped ,
129126 sample_indices_prepped ,
@@ -158,7 +155,6 @@ def biallelic_diplotype_pairwise_distances(
158155 max_missing_an = max_missing_an ,
159156 )
160157
161- # Try to retrieve results from the cache.
162158 try :
163159 results = self .results_cache_get (name = name , params = params )
164160
@@ -168,10 +164,24 @@ def biallelic_diplotype_pairwise_distances(
168164 )
169165 self .results_cache_set (name = name , params = params , results = results )
170166
171- # Unpack results.
172167 dist : np .ndarray = results ["dist" ]
173168 samples : np .ndarray = results ["samples" ]
174- n_snps_used : int = int (results ["n_snps" ][()]) # ensure scalar
169+ n_snps_used : int = int (results ["n_snps" ][()])
170+
171+ if return_dataset :
172+ import xarray as xr
173+ from scipy .spatial .distance import squareform
174+ dist_square = squareform (dist )
175+ ds = xr .Dataset (
176+ data_vars = {
177+ "dist" : (("sample_x" , "sample_y" ), dist_square ),
178+ },
179+ coords = {
180+ "sample_id" : ("sample_x" , samples ),
181+ },
182+ attrs = {"n_snps_used" : n_snps_used },
183+ )
184+ return ds
175185
176186 return dist , samples , n_snps_used
177187
0 commit comments