@@ -468,22 +468,50 @@ def plot_haplotype_clustering_advanced(
468468 figures .append (snp_trace )
469469 subplot_heights .append (25 )
470470
471- figures , subplot_heights , df_haps = self ._insert_hapclust_snp_trace (
472- transcript = snp_transcript ,
473- snp_query = snp_query ,
474- figures = figures ,
475- subplot_heights = subplot_heights ,
476- sample_sets = sample_sets ,
477- sample_query = sample_query ,
478- analysis = analysis ,
479- dendro_sample_id_order = dendro_sample_id_order ,
480- snp_filter_min_maf = snp_filter_min_maf ,
481- snp_colorscale = snp_colorscale ,
482- snp_row_height = snp_row_height ,
483- chunks = chunks ,
484- inline_array = inline_array ,
485- )
486- n_aa = df_haps .shape [0 ]
471+ n_snps_transcripts = []
472+ if isinstance (snp_transcript , str ):
473+ (
474+ figures ,
475+ subplot_heights ,
476+ n_snps_transcript ,
477+ ) = self ._insert_hapclust_snp_trace (
478+ transcript = snp_transcript ,
479+ snp_query = snp_query ,
480+ figures = figures ,
481+ subplot_heights = subplot_heights ,
482+ sample_sets = sample_sets ,
483+ sample_query = sample_query ,
484+ analysis = analysis ,
485+ dendro_sample_id_order = dendro_sample_id_order ,
486+ snp_filter_min_maf = snp_filter_min_maf ,
487+ snp_colorscale = snp_colorscale ,
488+ snp_row_height = snp_row_height ,
489+ chunks = chunks ,
490+ inline_array = inline_array ,
491+ )
492+ n_snps_transcripts .append (n_snps_transcript )
493+ elif isinstance (snp_transcript , list ):
494+ for st in snp_transcript :
495+ (
496+ figures ,
497+ subplot_heights ,
498+ n_snps_transcript ,
499+ ) = self ._insert_hapclust_snp_trace (
500+ transcript = st ,
501+ snp_query = snp_query ,
502+ figures = figures ,
503+ subplot_heights = subplot_heights ,
504+ sample_sets = sample_sets ,
505+ sample_query = sample_query ,
506+ analysis = analysis ,
507+ dendro_sample_id_order = dendro_sample_id_order ,
508+ snp_filter_min_maf = snp_filter_min_maf ,
509+ snp_colorscale = snp_colorscale ,
510+ snp_row_height = snp_row_height ,
511+ chunks = chunks ,
512+ inline_array = inline_array ,
513+ )
514+ n_snps_transcripts .append (n_snps_transcript )
487515
488516 # Calculate total height based on subplot heights, plus a fixed
489517 # additional component to allow for title, axes etc.
@@ -499,45 +527,57 @@ def plot_haplotype_clustering_advanced(
499527 n_snps = n_snps_cluster ,
500528 )
501529
502- fig ["layout" ]["yaxis" ]["title" ] = "Distance"
530+ fig ["layout" ]["yaxis" ]["title" ] = f "Distance ( { distance_metric } ) "
503531 fig .update_layout (
504532 title_font = dict (
505533 size = title_font_size ,
506534 ),
507535 legend = dict (itemsizing = legend_sizing , tracegroupgap = 0 ),
508536 )
509537
510- if snp_transcript and n_aa > 0 :
511- # add lines to aa plot
512- aa_idx = len (figures )
513- fig .add_hline (y = - 0.5 , line_width = 1 , line_color = "grey" , row = aa_idx , col = 1 )
514- for i in range (n_aa ):
515- fig .add_hline (
516- y = i + 0.5 , line_width = 1 , line_color = "grey" , row = aa_idx , col = 1
538+ # add lines to aa plot - looks neater
539+ if snp_transcript :
540+ n_transcripts = (
541+ len (snp_transcript ) if isinstance (snp_transcript , list ) else 1
542+ )
543+ for i in range (n_transcripts ):
544+ tx_idx = len (figures ) - n_transcripts + i + 1
545+ if n_snps_transcripts [i ] > 0 :
546+ fig .add_hline (
547+ y = - 0.5 , line_width = 1 , line_color = "grey" , row = tx_idx , col = 1
548+ )
549+ for j in range (n_snps_transcripts [i ]):
550+ fig .add_hline (
551+ y = j + 0.5 ,
552+ line_width = 1 ,
553+ line_color = "grey" ,
554+ row = tx_idx ,
555+ col = 1 ,
556+ )
557+
558+ fig .update_xaxes (
559+ showline = True ,
560+ linecolor = "grey" ,
561+ linewidth = 1 ,
562+ row = tx_idx ,
563+ col = 1 ,
564+ mirror = True ,
517565 )
518566
519- fig .update_xaxes (
520- showline = True ,
521- linecolor = "grey" ,
522- linewidth = 1 ,
523- row = aa_idx ,
524- col = 1 ,
525- mirror = True ,
526- )
527- fig .update_yaxes (
528- showline = True ,
529- linecolor = "grey" ,
530- linewidth = 1 ,
531- row = aa_idx ,
532- col = 1 ,
533- mirror = True ,
534- )
567+ fig .update_yaxes (
568+ showline = True ,
569+ linecolor = "grey" ,
570+ linewidth = 1 ,
571+ row = tx_idx ,
572+ col = 1 ,
573+ mirror = True ,
574+ )
535575
536576 if show :
537577 fig .show (renderer = renderer )
538578 return None
539579 else :
540- return fig , leaf_data , df_haps
580+ return fig , leaf_data
541581
542582 def transcript_haplotypes (
543583 self ,
@@ -642,6 +682,8 @@ def _insert_hapclust_snp_trace(
642682 df_haps = df_haps .assign (af = lambda x : x .sum (axis = 1 ) / x .shape [1 ])
643683 df_haps = df_haps .query ("af > @snp_filter_min_maf" ).drop (columns = "af" )
644684
685+ n_snps_transcript = df_haps .shape [0 ]
686+
645687 if not df_haps .empty :
646688 snp_trace = go .Heatmap (
647689 z = df_haps .values ,
@@ -660,7 +702,7 @@ def _insert_hapclust_snp_trace(
660702 print (
661703 f"No SNPs were found below { snp_filter_min_maf } allele frequency. Omitting SNP genotype plot."
662704 )
663- return figures , subplot_heights , df_haps
705+ return figures , subplot_heights , n_snps_transcript
664706
665707 def cut_dist_tree (
666708 self ,
0 commit comments