 import xarray as xr
 import allel
 import dask.array as da
-from hashlib import sha1
 from numpydoc_decorator import doc  # type: ignore

 from ..util import check_types, haplotype_frequencies
@@ -202,8 +201,9 @@ def haplotypes_frequencies_advanced(
     gt = gt.compute()

     # Count haplotypes.
-    count_rows: dict[str, int] = dict()
-    freq_rows = dict()
+    hap_freq: dict[np.int64, int] = dict()
+    hap_count: dict[np.int64, int] = dict()
+    hap_nob: dict[np.int64, int] = dict()
     freq_cols = dict()
     count_cols = dict()
     nobs_cols = dict()
@@ -213,43 +213,41 @@ def haplotypes_frequencies_advanced(
     for cohort in cohorts_iterator:
         cohort_key = cohort.taxon, cohort.area, cohort.period
         cohort_key_str = cohort.taxon + "_" + cohort.area + "_" + str(cohort.period)
-        count_rows = {k: 0 for k in count_rows.keys()}
+        hap_freq = {k: 0 for k in hap_freq.keys()}
+        hap_count = {k: 0 for k in hap_count.keys()}
+        hap_nob = {k: 0 for k in hap_nob.keys()}
         n_samples = cohort.size
         assert n_samples >= min_cohort_size
         sample_indices = group_samples_by_cohort.indices[cohort_key]
         loc_coh = [i in sample_indices for i in range(0, gt.shape[1])]
-        gt_coh = np.compress(loc_coh, gt, axis=1)
-        for i in range(0, n_samples):
-            for j in range(0, 2):
-                gt_cont = np.ascontiguousarray(gt_coh[:, i, j])
-                hap_hash = str(sha1(gt_cont).digest())
-                if hap_hash not in count_rows.keys():
-                    count_rows[hap_hash] = 1
-                else:
-                    count_rows[hap_hash] += 1
-        freq_rows = {k: i / (2 * n_samples) for k, i in count_rows.items()}
-        count_cols["count_" + cohort_key_str] = list(count_rows.values())
-        freq_cols["frq_" + cohort_key_str] = list(freq_rows.values())
-        nobs_cols["nobs_" + cohort_key_str] = [2 * n_samples] * len(freq_rows)
+        gt_coh = allel.GenotypeDaskArray(da.compress(loc_coh, gt, axis=1))
+        gt_hap = gt_coh.to_haplotypes().compute()
+        f, c, o = haplotype_frequencies(gt_hap)
+        hap_freq.update(f)
+        hap_count.update(c)
+        hap_nob.update(o)
+        count_cols["count_" + cohort_key_str] = list(hap_count.values())
+        freq_cols["frq_" + cohort_key_str] = list(hap_freq.values())
+        nobs_cols["nobs_" + cohort_key_str] = list(hap_nob.values())

     n_haps = np.max([len(i) for i in freq_cols.values()])
     freq_cols = {
         k: v + [0 for i in range(0, n_haps - len(v))] for k, v in freq_cols.items()
     }
-    df_freqs = pd.DataFrame(freq_cols, index=freq_rows.keys())
+    df_freqs = pd.DataFrame(freq_cols, index=hap_freq.keys())

     # Compute max_af.
     df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)})

     count_cols = {
         k: v + [0 for i in range(0, n_haps - len(v))] for k, v in count_cols.items()
     }
-    df_counts = pd.DataFrame(count_cols, index=freq_rows.keys())
+    df_counts = pd.DataFrame(count_cols, index=hap_count.keys())

     nobs_cols = {
         k: v + [0 for i in range(0, n_haps - len(v))] for k, v in nobs_cols.items()
     }
-    df_nobs = pd.DataFrame(nobs_cols, index=freq_rows.keys())
+    df_nobs = pd.DataFrame(nobs_cols, index=hap_nob.keys())

     # Build the final dataframe.
     df_haps = pd.concat([df_freqs, df_counts, df_nobs, df_max_af], axis=1)