|
3 | 3 | import pandas as pd |
4 | 4 | import numpy as np |
5 | 5 | import xarray as xr |
| 6 | +import allel |
| 7 | +import dask.array as da |
6 | 8 | from hashlib import sha1 |
7 | 9 | from numpydoc_decorator import doc # type: ignore |
8 | 10 |
|
9 | | -from ..util import check_types |
| 11 | +from ..util import check_types, haplotype_frequencies |
10 | 12 | from .hap_data import AnophelesHapData |
11 | 13 | from .sample_metadata import locate_cohorts |
12 | 14 | from . import base_params, frq_params # , map_params, plotly_params |
@@ -40,7 +42,7 @@ def __init__( |
40 | 42 | output data frame. |
41 | 43 | """, |
42 | 44 | ) |
43 | | - def haplotype_frequencies( |
| 45 | + def haplotypes_frequencies( |
44 | 46 | self, |
45 | 47 | region: base_params.region, |
46 | 48 | cohorts: base_params.cohorts, |
@@ -85,38 +87,31 @@ def haplotype_frequencies( |
85 | 87 | raise ValueError("No SNPs available for the given region.") |
86 | 88 |
|
87 | 89 | # Access genotypes. |
88 | | - gt = ds_hap["call_genotype"].data |
| 90 | + gt = allel.GenotypeDaskArray(ds_hap["call_genotype"].data) |
89 | 91 | with self._dask_progress(desc="Compute haplotypes"): |
90 | 92 | gt = gt.compute() |
91 | 93 |
|
92 | 94 | # Count haplotypes. |
93 | | - count_rows: dict[str, int] = dict() |
94 | | - freq_rows = dict() |
95 | 95 | freq_cols = dict() |
96 | 96 | cohorts_iterator = self._progress( |
97 | 97 | coh_dict.items(), desc="Compute allele frequencies" |
98 | 98 | ) |
| 99 | + hap_track = {} |
99 | 100 | for coh, loc_coh in cohorts_iterator: |
100 | | - count_rows = {k: 0 for k in count_rows.keys()} |
| 101 | + hap_track = {k: 0 for k in hap_track.keys()} |
101 | 102 | n_samples = np.count_nonzero(loc_coh) |
102 | 103 | assert n_samples >= min_cohort_size |
103 | | - gt_coh = np.compress(loc_coh, gt, axis=1).copy(order="C") |
104 | | - for i in range(0, n_samples): |
105 | | - for j in range(0, 2): |
106 | | - gt_cont = np.ascontiguousarray(gt_coh[:, i, j]) |
107 | | - hap_hash = str(sha1(gt_cont).digest()) |
108 | | - if hap_hash not in count_rows.keys(): |
109 | | - count_rows[hap_hash] = 1 |
110 | | - else: |
111 | | - count_rows[hap_hash] += 1 |
112 | | - freq_rows = {k: i / (2 * n_samples) for k, i in count_rows.items()} |
113 | | - freq_cols["frq_" + coh] = list(freq_rows.values()) |
| 104 | + gt_coh = allel.GenotypeDaskArray(da.compress(loc_coh, gt, axis=1)) |
| 105 | + gt_hap = gt_coh.to_haplotypes().compute() |
| 106 | + f, _, _ = haplotype_frequencies(gt_hap) |
| 107 | + hap_track.update(f) |
| 108 | + freq_cols["frq_" + coh] = list(hap_track.values()) |
114 | 109 |
|
115 | 110 | n_haps = np.max([len(i) for i in freq_cols.values()]) |
116 | 111 | freq_cols = { |
117 | 112 | k: v + [0 for i in range(0, n_haps - len(v))] for k, v in freq_cols.items() |
118 | 113 | } |
119 | | - df_freqs = pd.DataFrame(freq_cols, index=freq_rows.keys()) |
| 114 | + df_freqs = pd.DataFrame(freq_cols, index=hap_track.keys()) |
120 | 115 |
|
121 | 116 | # Compute max_af. |
122 | 117 | df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)}) |
@@ -149,7 +144,7 @@ def haplotype_frequencies( |
149 | 144 | counts and frequency calculations. |
150 | 145 | """, |
151 | 146 | ) |
152 | | - def haplotype_frequencies_advanced( |
| 147 | + def haplotypes_frequencies_advanced( |
153 | 148 | self, |
154 | 149 | region: base_params.region, |
155 | 150 | area_by: frq_params.area_by, |
|
0 commit comments