|
4 | 4 | import numpy as np |
5 | 5 | import xarray as xr |
6 | 6 | import allel |
7 | | -import dask.array as da |
8 | 7 | from numpydoc_decorator import doc # type: ignore |
9 | 8 |
|
10 | 9 | from ..util import check_types, haplotype_frequencies |
@@ -95,8 +94,8 @@ def haplotypes_frequencies( |
95 | 94 |
|
96 | 95 | n_samples = np.count_nonzero(loc_coh) |
97 | 96 | assert n_samples >= min_cohort_size |
98 | | - gt_coh = allel.GenotypeDaskArray(da.compress(loc_coh, gt, axis=1)) |
99 | | - gt_hap = gt_coh.to_haplotypes().compute() |
| 97 | + gt_coh = gt.compress(loc_coh, axis=1) |
| 98 | + gt_hap = gt_coh.to_haplotypes() |
100 | 99 | f, _, _ = haplotype_frequencies(gt_hap) |
101 | 100 | # The frequencies of the observed haplotypes are then updated |
102 | 101 | hap_dict.update(f) |
@@ -171,12 +170,6 @@ def haplotypes_frequencies_advanced( |
171 | 170 | min_cohort_size=min_cohort_size, |
172 | 171 | ) |
173 | 172 |
|
174 | | - # Early check for no cohorts. |
175 | | - if len(df_cohorts) == 0: |
176 | | - raise ValueError( |
177 | | - "No cohorts available for the given sample selection parameters and minimum cohort size." |
178 | | - ) |
179 | | - |
180 | 173 | # Access haplotypes. |
181 | 174 | ds_haps = self.haplotypes( |
182 | 175 | region=region, |
@@ -220,9 +213,8 @@ def haplotypes_frequencies_advanced( |
220 | 213 | hap_nob = {k: 2 * n_samples for k in f_all.keys()} |
221 | 214 | assert n_samples >= min_cohort_size |
222 | 215 | sample_indices = group_samples_by_cohort.indices[cohort_key] |
223 | | - loc_coh = [i in sample_indices for i in range(0, gt.shape[1])] |
224 | | - gt_coh = allel.GenotypeDaskArray(da.compress(loc_coh, gt, axis=1)) |
225 | | - gt_hap = gt_coh.to_haplotypes().compute() |
| 216 | + gt_coh = gt.take(sample_indices, axis=1) |
| 217 | + gt_hap = gt_coh.to_haplotypes() |
226 | 218 | f, c, o = haplotype_frequencies(gt_hap) |
227 | 219 | # The frequencies and counts of the observed haplotypes are then updated, so are the nobs but the values should actually stay the same |
228 | 220 | hap_freq.update(f) |
@@ -342,6 +334,12 @@ def _build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_s |
342 | 334 | # Apply minimum cohort size. |
343 | 335 | df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True) |
344 | 336 |
|
| 337 | + # Early check for no cohorts. |
| 338 | + if len(df_cohorts) == 0: |
| 339 | + raise ValueError( |
| 340 | + "No cohorts available for the given sample selection parameters and minimum cohort size." |
| 341 | + ) |
| 342 | + |
345 | 343 | return df_cohorts |
346 | 344 |
|
347 | 345 |
|
|
0 commit comments