Skip to content

Commit ba7a7ed

Browse files
committed
Passed the tests. I am not going to check that it does what I want until I have had dinner.
1 parent ad6c811 commit ba7a7ed

1 file changed

Lines changed: 14 additions & 19 deletions

File tree

malariagen_data/anoph/hap_freq.py

Lines changed: 14 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,12 @@
33
import pandas as pd
44
import numpy as np
55
import xarray as xr
6+
import allel
7+
import dask.array as da
68
from hashlib import sha1
79
from numpydoc_decorator import doc # type: ignore
810

9-
from ..util import check_types
11+
from ..util import check_types, haplotype_frequencies
1012
from .hap_data import AnophelesHapData
1113
from .sample_metadata import locate_cohorts
1214
from . import base_params, frq_params # , map_params, plotly_params
@@ -40,7 +42,7 @@ def __init__(
4042
output data frame.
4143
""",
4244
)
43-
def haplotype_frequencies(
45+
def haplotypes_frequencies(
4446
self,
4547
region: base_params.region,
4648
cohorts: base_params.cohorts,
@@ -85,38 +87,31 @@ def haplotype_frequencies(
8587
raise ValueError("No SNPs available for the given region.")
8688

8789
# Access genotypes.
88-
gt = ds_hap["call_genotype"].data
90+
gt = allel.GenotypeDaskArray(ds_hap["call_genotype"].data)
8991
with self._dask_progress(desc="Compute haplotypes"):
9092
gt = gt.compute()
9193

9294
# Count haplotypes.
93-
count_rows: dict[str, int] = dict()
94-
freq_rows = dict()
9595
freq_cols = dict()
9696
cohorts_iterator = self._progress(
9797
coh_dict.items(), desc="Compute allele frequencies"
9898
)
99+
hap_track = {}
99100
for coh, loc_coh in cohorts_iterator:
100-
count_rows = {k: 0 for k in count_rows.keys()}
101+
hap_track = {k: 0 for k in hap_track.keys()}
101102
n_samples = np.count_nonzero(loc_coh)
102103
assert n_samples >= min_cohort_size
103-
gt_coh = np.compress(loc_coh, gt, axis=1).copy(order="C")
104-
for i in range(0, n_samples):
105-
for j in range(0, 2):
106-
gt_cont = np.ascontiguousarray(gt_coh[:, i, j])
107-
hap_hash = str(sha1(gt_cont).digest())
108-
if hap_hash not in count_rows.keys():
109-
count_rows[hap_hash] = 1
110-
else:
111-
count_rows[hap_hash] += 1
112-
freq_rows = {k: i / (2 * n_samples) for k, i in count_rows.items()}
113-
freq_cols["frq_" + coh] = list(freq_rows.values())
104+
gt_coh = allel.GenotypeDaskArray(da.compress(loc_coh, gt, axis=1))
105+
gt_hap = gt_coh.to_haplotypes().compute()
106+
f, _, _ = haplotype_frequencies(gt_hap)
107+
hap_track.update(f)
108+
freq_cols["frq_" + coh] = list(hap_track.values())
114109

115110
n_haps = np.max([len(i) for i in freq_cols.values()])
116111
freq_cols = {
117112
k: v + [0 for i in range(0, n_haps - len(v))] for k, v in freq_cols.items()
118113
}
119-
df_freqs = pd.DataFrame(freq_cols, index=freq_rows.keys())
114+
df_freqs = pd.DataFrame(freq_cols, index=hap_track.keys())
120115

121116
# Compute max_af.
122117
df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)})
@@ -149,7 +144,7 @@ def haplotype_frequencies(
149144
counts and frequency calculations.
150145
""",
151146
)
152-
def haplotype_frequencies_advanced(
147+
def haplotypes_frequencies_advanced(
153148
self,
154149
region: base_params.region,
155150
area_by: frq_params.area_by,

0 commit comments

Comments
 (0)