@@ -777,6 +777,356 @@ def cut_dist_tree(
777777
778778 return df
779779
780+ def _compute_haplotype_sharing (
781+ self ,
782+ * ,
783+ region ,
784+ analysis ,
785+ sample_sets ,
786+ sample_query ,
787+ sample_query_options ,
788+ cohort_size ,
789+ random_seed ,
790+ cohort_col ,
791+ chunks ,
792+ inline_array ,
793+ ):
794+ ds_haps = self .haplotypes (
795+ region = region ,
796+ analysis = analysis ,
797+ sample_sets = sample_sets ,
798+ sample_query = sample_query ,
799+ sample_query_options = sample_query_options ,
800+ cohort_size = cohort_size ,
801+ random_seed = random_seed ,
802+ chunks = chunks ,
803+ inline_array = inline_array ,
804+ )
805+
806+ df_samples = self .sample_metadata (
807+ sample_sets = sample_sets ,
808+ sample_query = sample_query ,
809+ sample_query_options = sample_query_options ,
810+ )
811+
812+ samples_phased = ds_haps ["sample_id" ].values
813+ df_samples_phased = (
814+ df_samples .set_index ("sample_id" ).loc [samples_phased ].reset_index ()
815+ )
816+ df_haps = df_samples_phased .loc [df_samples_phased .index .repeat (2 )].reset_index (
817+ drop = True
818+ )
819+
820+ if cohort_col not in df_haps .columns :
821+ raise ValueError (
822+ f"Column '{ cohort_col } ' not found in sample metadata. "
823+ f"Available columns: { ', ' .join (df_haps .columns )} "
824+ )
825+
826+ gt = allel .GenotypeDaskArray (ds_haps ["call_genotype" ].data )
827+ with self ._dask_progress (desc = "Load haplotypes" ):
828+ ht = gt .to_haplotypes ().compute ()
829+
830+ ac = ht .count_alleles (max_allele = 1 )
831+ ht_seg = ht [ac .is_segregating ()]
832+
833+ ht_distinct_sets = ht_seg .distinct ()
834+
835+ cohort_labels = df_haps [cohort_col ].values
836+ cohorts = pd .unique (df_haps [cohort_col ].dropna ())
837+
838+ sharing_matrix = pd .DataFrame (0 , index = cohorts , columns = cohorts , dtype = int )
839+
840+ for s in ht_distinct_sets :
841+ indices = list (s )
842+ cohorts_in_set = set (cohort_labels [indices ])
843+ cohorts_in_set .discard (None )
844+ cohorts_in_set = [c for c in cohorts_in_set if pd .notna (c )]
845+ for i , c1 in enumerate (cohorts_in_set ):
846+ for c2 in cohorts_in_set [i + 1 :]:
847+ sharing_matrix .loc [c1 , c2 ] += 1
848+ sharing_matrix .loc [c2 , c1 ] += 1
849+
850+ return sharing_matrix , cohorts
851+
852+ @_check_types
853+ @doc (
854+ summary = """
855+ Plot an arc diagram showing haplotype sharing between cohorts.
856+ """ ,
857+ )
858+ def plot_haplotype_sharing_arc (
859+ self ,
860+ region : base_params .regions ,
861+ cohort_col : hapclust_params .cohort_col ,
862+ analysis : hap_params .analysis = base_params .DEFAULT ,
863+ sample_sets : Optional [base_params .sample_sets ] = None ,
864+ sample_query : Optional [base_params .sample_query ] = None ,
865+ sample_query_options : Optional [base_params .sample_query_options ] = None ,
866+ cohort_size : Optional [base_params .cohort_size ] = None ,
867+ random_seed : base_params .random_seed = 42 ,
868+ title : plotly_params .title = True ,
869+ width : plotly_params .fig_width = None ,
870+ height : plotly_params .fig_height = 500 ,
871+ show : plotly_params .show = True ,
872+ renderer : plotly_params .renderer = None ,
873+ chunks : base_params .chunks = base_params .native_chunks ,
874+ inline_array : base_params .inline_array = base_params .inline_array_default ,
875+ ) -> plotly_params .figure :
876+ import plotly .graph_objects as go
877+ import plotly .express as px
878+
879+ sharing_matrix , cohorts = self ._compute_haplotype_sharing (
880+ region = region ,
881+ analysis = analysis ,
882+ sample_sets = sample_sets ,
883+ sample_query = sample_query ,
884+ sample_query_options = sample_query_options ,
885+ cohort_size = cohort_size ,
886+ random_seed = random_seed ,
887+ cohort_col = cohort_col ,
888+ chunks = chunks ,
889+ inline_array = inline_array ,
890+ )
891+
892+ n_cohorts = len (cohorts )
893+ cohort_list = list (cohorts )
894+ x_positions = np .linspace (0 , 1 , n_cohorts )
895+ cohort_x = {c : x for c , x in zip (cohort_list , x_positions )}
896+
897+ palette = px .colors .qualitative .Dark24
898+ cohort_colors = {
899+ c : palette [i % len (palette )] for i , c in enumerate (cohort_list )
900+ }
901+
902+ fig = go .Figure ()
903+
904+ max_sharing = sharing_matrix .values .max ()
905+ if max_sharing == 0 :
906+ max_sharing = 1
907+
908+ for i in range (n_cohorts ):
909+ for j in range (i + 1 , n_cohorts ):
910+ c1 = cohort_list [i ]
911+ c2 = cohort_list [j ]
912+ count = sharing_matrix .loc [c1 , c2 ]
913+ if count == 0 :
914+ continue
915+
916+ x1 = cohort_x [c1 ]
917+ x2 = cohort_x [c2 ]
918+ arc_height = abs (x2 - x1 ) * 0.5
919+
920+ line_width = 1 + (count / max_sharing ) * 9
921+
922+ t = np .linspace (0 , np .pi , 50 )
923+ arc_x = x1 + (x2 - x1 ) * (1 - np .cos (t )) / 2
924+ arc_y = arc_height * np .sin (t )
925+
926+ fig .add_trace (
927+ go .Scatter (
928+ x = arc_x ,
929+ y = arc_y ,
930+ mode = "lines" ,
931+ line = dict (width = line_width , color = cohort_colors [c1 ]),
932+ hoverinfo = "text" ,
933+ text = f"{ c1 } ↔ { c2 } : { count } shared" ,
934+ showlegend = False ,
935+ )
936+ )
937+
938+ fig .add_trace (
939+ go .Scatter (
940+ x = x_positions ,
941+ y = np .zeros (n_cohorts ),
942+ mode = "markers+text" ,
943+ marker = dict (size = 12 , color = [cohort_colors [c ] for c in cohort_list ]),
944+ text = cohort_list ,
945+ textposition = "bottom center" ,
946+ textfont = dict (size = 10 ),
947+ hoverinfo = "text" ,
948+ hovertext = cohort_list ,
949+ showlegend = False ,
950+ )
951+ )
952+
953+ if title is True :
954+ title_lines = []
955+ if sample_sets is not None :
956+ title_lines .append (f"Sample sets: { sample_sets } " )
957+ if sample_query is not None :
958+ title_lines .append (f"Sample query: { sample_query } " )
959+ title_lines .append (f"Genomic region: { region } " )
960+ title_lines .append (f"Cohorts grouped by: { cohort_col } " )
961+ title = "<br>" .join (title_lines )
962+
963+ fig .update_layout (
964+ title = title ,
965+ xaxis = dict (
966+ showticklabels = False ,
967+ showgrid = False ,
968+ zeroline = False ,
969+ ),
970+ yaxis = dict (
971+ showticklabels = False ,
972+ showgrid = False ,
973+ zeroline = False ,
974+ ),
975+ width = width or 800 ,
976+ height = height ,
977+ plot_bgcolor = "white" ,
978+ )
979+
980+ if show : # pragma: no cover
981+ fig .show (renderer = renderer )
982+ return None
983+ else :
984+ return fig
985+
986+ @_check_types
987+ @doc (
988+ summary = """
989+ Plot a chord diagram showing haplotype sharing between cohorts.
990+ """ ,
991+ )
992+ def plot_haplotype_sharing_chord (
993+ self ,
994+ region : base_params .regions ,
995+ cohort_col : hapclust_params .cohort_col ,
996+ analysis : hap_params .analysis = base_params .DEFAULT ,
997+ sample_sets : Optional [base_params .sample_sets ] = None ,
998+ sample_query : Optional [base_params .sample_query ] = None ,
999+ sample_query_options : Optional [base_params .sample_query_options ] = None ,
1000+ cohort_size : Optional [base_params .cohort_size ] = None ,
1001+ random_seed : base_params .random_seed = 42 ,
1002+ title : plotly_params .title = True ,
1003+ width : plotly_params .fig_width = 600 ,
1004+ height : plotly_params .fig_height = 600 ,
1005+ show : plotly_params .show = True ,
1006+ renderer : plotly_params .renderer = None ,
1007+ chunks : base_params .chunks = base_params .native_chunks ,
1008+ inline_array : base_params .inline_array = base_params .inline_array_default ,
1009+ ) -> plotly_params .figure :
1010+ import plotly .graph_objects as go
1011+ import plotly .express as px
1012+
1013+ sharing_matrix , cohorts = self ._compute_haplotype_sharing (
1014+ region = region ,
1015+ analysis = analysis ,
1016+ sample_sets = sample_sets ,
1017+ sample_query = sample_query ,
1018+ sample_query_options = sample_query_options ,
1019+ cohort_size = cohort_size ,
1020+ random_seed = random_seed ,
1021+ cohort_col = cohort_col ,
1022+ chunks = chunks ,
1023+ inline_array = inline_array ,
1024+ )
1025+
1026+ n_cohorts = len (cohorts )
1027+ cohort_list = list (cohorts )
1028+
1029+ angles = np .linspace (0 , 2 * np .pi , n_cohorts , endpoint = False )
1030+ radius = 1.0
1031+ cohort_x = {c : radius * np .cos (a ) for c , a in zip (cohort_list , angles )}
1032+ cohort_y = {c : radius * np .sin (a ) for c , a in zip (cohort_list , angles )}
1033+
1034+ palette = px .colors .qualitative .Dark24
1035+ cohort_colors = {
1036+ c : palette [i % len (palette )] for i , c in enumerate (cohort_list )
1037+ }
1038+
1039+ fig = go .Figure ()
1040+
1041+ max_sharing = sharing_matrix .values .max ()
1042+ if max_sharing == 0 :
1043+ max_sharing = 1
1044+
1045+ for i in range (n_cohorts ):
1046+ for j in range (i + 1 , n_cohorts ):
1047+ c1 = cohort_list [i ]
1048+ c2 = cohort_list [j ]
1049+ count = sharing_matrix .loc [c1 , c2 ]
1050+ if count == 0 :
1051+ continue
1052+
1053+ x1 , y1 = cohort_x [c1 ], cohort_y [c1 ]
1054+ x2 , y2 = cohort_x [c2 ], cohort_y [c2 ]
1055+
1056+ line_width = 1 + (count / max_sharing ) * 9
1057+
1058+ t = np .linspace (0 , 1 , 50 )
1059+ cx , cy = 0 , 0
1060+ chord_x = (1 - t ) ** 2 * x1 + 2 * (1 - t ) * t * cx + t ** 2 * x2
1061+ chord_y = (1 - t ) ** 2 * y1 + 2 * (1 - t ) * t * cy + t ** 2 * y2
1062+
1063+ fig .add_trace (
1064+ go .Scatter (
1065+ x = chord_x ,
1066+ y = chord_y ,
1067+ mode = "lines" ,
1068+ line = dict (width = line_width , color = cohort_colors [c1 ]),
1069+ opacity = 0.6 ,
1070+ hoverinfo = "text" ,
1071+ text = f"{ c1 } ↔ { c2 } : { count } shared" ,
1072+ showlegend = False ,
1073+ )
1074+ )
1075+
1076+ label_x = [cohort_x [c ] for c in cohort_list ]
1077+ label_y = [cohort_y [c ] for c in cohort_list ]
1078+
1079+ fig .add_trace (
1080+ go .Scatter (
1081+ x = label_x ,
1082+ y = label_y ,
1083+ mode = "markers+text" ,
1084+ marker = dict (size = 14 , color = [cohort_colors [c ] for c in cohort_list ]),
1085+ text = cohort_list ,
1086+ textposition = [
1087+ "top center" if y >= 0 else "bottom center" for y in label_y
1088+ ],
1089+ textfont = dict (size = 10 ),
1090+ hoverinfo = "text" ,
1091+ hovertext = cohort_list ,
1092+ showlegend = False ,
1093+ )
1094+ )
1095+
1096+ if title is True :
1097+ title_lines = []
1098+ if sample_sets is not None :
1099+ title_lines .append (f"Sample sets: { sample_sets } " )
1100+ if sample_query is not None :
1101+ title_lines .append (f"Sample query: { sample_query } " )
1102+ title_lines .append (f"Genomic region: { region } " )
1103+ title_lines .append (f"Cohorts grouped by: { cohort_col } " )
1104+ title = "<br>" .join (title_lines )
1105+
1106+ fig .update_layout (
1107+ title = title ,
1108+ xaxis = dict (
1109+ showticklabels = False ,
1110+ showgrid = False ,
1111+ zeroline = False ,
1112+ scaleanchor = "y" ,
1113+ ),
1114+ yaxis = dict (
1115+ showticklabels = False ,
1116+ showgrid = False ,
1117+ zeroline = False ,
1118+ ),
1119+ width = width ,
1120+ height = height ,
1121+ plot_bgcolor = "white" ,
1122+ )
1123+
1124+ if show : # pragma: no cover
1125+ fig .show (renderer = renderer )
1126+ return None
1127+ else :
1128+ return fig
1129+
7801130
7811131def _filter_and_remap (arr , x ):
7821132 from collections import Counter
0 commit comments