@@ -790,6 +790,7 @@ def _compute_haplotype_sharing(
790790 cohort_col ,
791791 chunks ,
792792 inline_array ,
793+ metric = "unique" ,
793794 ):
794795 """
795796 Computes the number of identical haplotypes shared between cohorts
@@ -852,14 +853,26 @@ def _compute_haplotype_sharing(
852853
853854 for s in ht_distinct_sets :
854855 indices = list (s )
855- cohorts_in_set = set (cohort_labels [indices ])
856- cohorts_in_set .discard (None )
857- cohorts_in_set = [c for c in cohorts_in_set if pd .notna (c )]
858- # Increment the symmetric matrix for every unique pair of cohorts in this group
859- for i , c1 in enumerate (cohorts_in_set ):
860- for c2 in cohorts_in_set [i + 1 :]:
861- sharing_matrix .loc [c1 , c2 ] += 1
862- sharing_matrix .loc [c2 , c1 ] += 1
856+ labels_in_group = cohort_labels [indices ]
857+
858+ if metric == "unique" :
859+ cohorts_in_set = set (labels_in_group )
860+ cohorts_in_set .discard (None )
861+ cohorts_in_set = [c for c in cohorts_in_set if pd .notna (c )]
862+ for i , c1 in enumerate (cohorts_in_set ):
863+ for c2 in cohorts_in_set [i + 1 :]:
864+ sharing_matrix .loc [c1 , c2 ] += 1
865+ sharing_matrix .loc [c2 , c1 ] += 1
866+ else :
867+ from collections import Counter
868+
869+ counts = Counter (labels_in_group )
870+ unique_cohorts = [c for c in counts .keys () if pd .notna (c )]
871+ for i , c1 in enumerate (unique_cohorts ):
872+ for c2 in unique_cohorts [i + 1 :]:
873+ shared = counts [c1 ] * counts [c2 ]
874+ sharing_matrix .loc [c1 , c2 ] += shared
875+ sharing_matrix .loc [c2 , c1 ] += shared
863876
864877 return sharing_matrix , cohorts
865878
@@ -886,6 +899,7 @@ def plot_haplotype_sharing_arc(
886899 renderer : plotly_params .renderer = None ,
887900 chunks : base_params .chunks = base_params .native_chunks ,
888901 inline_array : base_params .inline_array = base_params .inline_array_default ,
902+ metric : str = "unique" ,
889903 ) -> plotly_params .figure :
890904 import plotly .graph_objects as go
891905 import plotly .express as px
@@ -901,6 +915,7 @@ def plot_haplotype_sharing_arc(
901915 cohort_col = cohort_col ,
902916 chunks = chunks ,
903917 inline_array = inline_array ,
918+ metric = metric ,
904919 )
905920
906921 n_cohorts = len (cohorts )
@@ -1021,6 +1036,7 @@ def plot_haplotype_sharing_chord(
10211036 renderer : plotly_params .renderer = None ,
10221037 chunks : base_params .chunks = base_params .native_chunks ,
10231038 inline_array : base_params .inline_array = base_params .inline_array_default ,
1039+ metric : str = "unique" ,
10241040 ) -> plotly_params .figure :
10251041 import plotly .graph_objects as go
10261042 import plotly .express as px
@@ -1036,6 +1052,7 @@ def plot_haplotype_sharing_chord(
10361052 cohort_col = cohort_col ,
10371053 chunks = chunks ,
10381054 inline_array = inline_array ,
1055+ metric = metric ,
10391056 )
10401057
10411058 n_cohorts = len (cohorts )
0 commit comments