Skip to content

Commit 319df58

Browse files
authored
Improve neighbour-joining tree performance (#641)
* begin njt refactoring
* refactor plot_njt
* fix docs
* tweak
* add tests
* fix nan comparisons
* poetry update
* consistent color for NA
* fix karyotype categorical
* fix warning
* fix haplotype networks
1 parent c2d4752 commit 319df58

15 files changed

Lines changed: 1816 additions & 40923 deletions

File tree

docs/source/Af1.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -134,6 +134,7 @@ Genetic distance and neighbour-joining trees (NJT)
134134
:toctree: generated/
135135

136136
plot_njt
137+
njt
137138
biallelic_diplotype_pairwise_distances
138139

139140
Heterozygosity analysis

docs/source/Ag3.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,7 @@ Genetic distance and neighbour-joining trees (NJT)
144144
:toctree: generated/
145145

146146
plot_njt
147+
njt
147148
biallelic_diplotype_pairwise_distances
148149

149150
Heterozygosity analysis

malariagen_data/ag3.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -392,17 +392,20 @@ def karyotype(
392392
ds_snps = self.snp_calls(
393393
region=region, sample_sets=sample_sets, sample_query=sample_query
394394
)
395-
geno = allel.GenotypeDaskArray(ds_snps["call_genotype"].data)
396-
pos = allel.SortedIndex(ds_snps["variant_position"].values)
397-
samples = ds_snps["sample_id"].values
398-
alts = ds_snps["variant_allele"].values.astype(str)
399-
400-
# subset to position of inversion tags
401-
mask = pos.locate_intersection(inversion_pos)[0]
402-
alts = alts[mask]
403-
geno = geno.compress(mask, axis=0).compute()
404395

405396
with self._spinner("Inferring karyotype from tag SNPs"):
397+
# access variables we need
398+
geno = allel.GenotypeDaskArray(ds_snps["call_genotype"].data)
399+
pos = allel.SortedIndex(ds_snps["variant_position"].values)
400+
samples = ds_snps["sample_id"].values
401+
alts = ds_snps["variant_allele"].values.astype(str)
402+
403+
# subset to position of inversion tags
404+
mask = pos.locate_intersection(inversion_pos)[0]
405+
alts = alts[mask]
406+
geno = geno.compress(mask, axis=0).compute()
407+
408+
# infer karyotype
406409
gn_alt = _karyotype_tags_n_alt(
407410
gt=geno, alts=alts, inversion_alts=inversion_alts
408411
)
@@ -422,7 +425,8 @@ def karyotype(
422425
"total_tag_snps": total_sites,
423426
},
424427
)
425-
kt_dtype = CategoricalDtype(categories=[0, 1, 2], ordered=True)
428+
# Allow filling missing values with "<NA>" visible placeholder.
429+
kt_dtype = CategoricalDtype(categories=[0, 1, 2, "<NA>"], ordered=True)
426430
df[f"karyotype_{inversion}"] = df[f"karyotype_{inversion}"].astype(kt_dtype)
427431

428432
return df

malariagen_data/anoph/dipclust_params.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
"""Parameters for diplotype clustering functions."""
22

3-
from .diplotype_distance_params import distance_metric
3+
from .distance_params import distance_metric
44
from .clustering_params import linkage_method
55

66

0 commit comments

Comments
 (0)