Skip to content

Commit bc2424c

Browse files
committed
feat: add arc and chord diagrams for haplotype sharing (#457, #458)
1 parent 515ed4f commit bc2424c

3 files changed

Lines changed: 382 additions & 0 deletions

File tree

malariagen_data/anoph/hapclust.py

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,356 @@ def cut_dist_tree(
777777

778778
return df
779779

780+
def _compute_haplotype_sharing(
    self,
    *,
    region,
    analysis,
    sample_sets,
    sample_query,
    sample_query_options,
    cohort_size,
    random_seed,
    cohort_col,
    chunks,
    inline_array,
):
    """Count distinct haplotypes shared between each pair of cohorts.

    Haplotypes for the requested region are loaded, restricted to
    segregating sites, and grouped into sets of identical haplotypes. Each
    set that contains haplotypes from two different cohorts adds one to the
    count for that cohort pair.

    Parameters
    ----------
    region, analysis, sample_sets, sample_query, sample_query_options,
    cohort_size, random_seed, chunks, inline_array
        Forwarded unchanged to ``self.haplotypes()``. The sample selection
        parameters are also forwarded to ``self.sample_metadata()``.
    cohort_col : str
        Name of the sample metadata column whose values define the cohorts.

    Returns
    -------
    sharing_matrix : pandas.DataFrame
        Square, symmetric integer matrix with the cohort labels on both
        axes. The diagonal is always zero.
    cohorts : array-like
        Unique non-null cohort labels, as returned by ``pandas.unique``.

    Raises
    ------
    ValueError
        If ``cohort_col`` is not a column of the sample metadata.
    """
    ds_haps = self.haplotypes(
        region=region,
        analysis=analysis,
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        cohort_size=cohort_size,
        random_seed=random_seed,
        chunks=chunks,
        inline_array=inline_array,
    )

    df_samples = self.sample_metadata(
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
    )

    # Keep only the samples present in the haplotype dataset, in the same
    # order as they appear there.
    samples_phased = ds_haps["sample_id"].values
    df_samples_phased = (
        df_samples.set_index("sample_id").loc[samples_phased].reset_index()
    )
    # Two metadata rows per sample, so that rows line up with haplotypes.
    # NOTE(review): assumes diploid calls (two haplotypes per sample) --
    # confirm against the phasing analysis in use.
    df_haps = df_samples_phased.loc[df_samples_phased.index.repeat(2)].reset_index(
        drop=True
    )

    if cohort_col not in df_haps.columns:
        raise ValueError(
            f"Column '{cohort_col}' not found in sample metadata. "
            f"Available columns: {', '.join(df_haps.columns)}"
        )

    gt = allel.GenotypeDaskArray(ds_haps["call_genotype"].data)
    with self._dask_progress(desc="Load haplotypes"):
        ht = gt.to_haplotypes().compute()

    # Only segregating sites are used when comparing haplotypes.
    ac = ht.count_alleles(max_allele=1)
    ht_seg = ht[ac.is_segregating()]

    # Each element is a collection of haplotype indices that are identical.
    ht_distinct_sets = ht_seg.distinct()

    cohort_labels = df_haps[cohort_col].values
    cohorts = pd.unique(df_haps[cohort_col].dropna())

    # Accumulate counts in a plain integer matrix addressed by cohort
    # position; this avoids one pandas label lookup per cohort pair.
    cohort_index = {label: pos for pos, label in enumerate(cohorts)}
    n_cohorts = len(cohorts)
    counts = np.zeros((n_cohorts, n_cohorts), dtype=int)

    for hap_set in ht_distinct_sets:
        # Positions of the distinct, non-null cohorts carrying this haplotype.
        members = sorted(
            {
                cohort_index[label]
                for label in cohort_labels[list(hap_set)]
                if pd.notna(label)
            }
        )
        if len(members) > 1:
            # Adds one to every (a, b) combination of members, including
            # a == b; the diagonal is cleared once after the loop.
            counts[np.ix_(members, members)] += 1

    # A cohort is never counted as sharing with itself.
    np.fill_diagonal(counts, 0)

    sharing_matrix = pd.DataFrame(counts, index=cohorts, columns=cohorts)

    return sharing_matrix, cohorts
851+
852+
@_check_types
@doc(
    summary="""
        Plot an arc diagram showing haplotype sharing between cohorts.
    """,
)
def plot_haplotype_sharing_arc(
    self,
    region: base_params.regions,
    cohort_col: hapclust_params.cohort_col,
    analysis: hap_params.analysis = base_params.DEFAULT,
    sample_sets: Optional[base_params.sample_sets] = None,
    sample_query: Optional[base_params.sample_query] = None,
    sample_query_options: Optional[base_params.sample_query_options] = None,
    cohort_size: Optional[base_params.cohort_size] = None,
    random_seed: base_params.random_seed = 42,
    title: plotly_params.title = True,
    width: plotly_params.fig_width = None,
    height: plotly_params.fig_height = 500,
    show: plotly_params.show = True,
    renderer: plotly_params.renderer = None,
    chunks: base_params.chunks = base_params.native_chunks,
    inline_array: base_params.inline_array = base_params.inline_array_default,
) -> plotly_params.figure:
    # Cohorts are placed evenly along a horizontal baseline; every pair with
    # a non-zero sharing count is joined by a half-ellipse arc whose line
    # width scales with the count. Returns the figure when show is false,
    # otherwise displays it and returns None.
    import plotly.graph_objects as go
    import plotly.express as px

    sharing_matrix, cohorts = self._compute_haplotype_sharing(
        region=region,
        analysis=analysis,
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        cohort_size=cohort_size,
        random_seed=random_seed,
        cohort_col=cohort_col,
        chunks=chunks,
        inline_array=inline_array,
    )

    n_cohorts = len(cohorts)
    cohort_list = list(cohorts)
    x_positions = np.linspace(0, 1, n_cohorts)
    cohort_x = dict(zip(cohort_list, x_positions))

    # The palette is cycled if there are more cohorts than colours.
    palette = px.colors.qualitative.Dark24
    cohort_colors = {
        c: palette[i % len(palette)] for i, c in enumerate(cohort_list)
    }

    fig = go.Figure()

    # Normalisation constant for line widths. `initial=0` keeps the reduction
    # defined when there are no cohorts at all (empty matrix), and the floor
    # of 1 avoids dividing by zero when nothing is shared.
    max_sharing = max(sharing_matrix.values.max(initial=0), 1)

    # Arc parameterisation is the same for every pair, so build it once.
    t = np.linspace(0, np.pi, 50)
    one_minus_cos_t = 1 - np.cos(t)
    sin_t = np.sin(t)

    for i, c1 in enumerate(cohort_list):
        for c2 in cohort_list[i + 1 :]:
            count = sharing_matrix.loc[c1, c2]
            if count == 0:
                continue

            x1 = cohort_x[c1]
            x2 = cohort_x[c2]
            # Arcs between more distant cohorts rise higher.
            arc_height = abs(x2 - x1) * 0.5

            # Width ranges from just above 1 up to 10 for the largest count.
            line_width = 1 + (count / max_sharing) * 9

            arc_x = x1 + (x2 - x1) * one_minus_cos_t / 2
            arc_y = arc_height * sin_t

            fig.add_trace(
                go.Scatter(
                    x=arc_x,
                    y=arc_y,
                    mode="lines",
                    line=dict(width=line_width, color=cohort_colors[c1]),
                    hoverinfo="text",
                    # Separate the two labels so the hover text is readable.
                    text=f"{c1} - {c2}: {count} shared",
                    showlegend=False,
                )
            )

    # Cohort markers and labels along the baseline.
    fig.add_trace(
        go.Scatter(
            x=x_positions,
            y=np.zeros(n_cohorts),
            mode="markers+text",
            marker=dict(size=12, color=[cohort_colors[c] for c in cohort_list]),
            text=cohort_list,
            textposition="bottom center",
            textfont=dict(size=10),
            hoverinfo="text",
            hovertext=cohort_list,
            showlegend=False,
        )
    )

    # title=True requests an automatically generated title; any other value
    # is passed through to plotly as given.
    if title is True:
        title_lines = []
        if sample_sets is not None:
            title_lines.append(f"Sample sets: {sample_sets}")
        if sample_query is not None:
            title_lines.append(f"Sample query: {sample_query}")
        title_lines.append(f"Genomic region: {region}")
        title_lines.append(f"Cohorts grouped by: {cohort_col}")
        title = "<br>".join(title_lines)

    fig.update_layout(
        title=title,
        xaxis=dict(
            showticklabels=False,
            showgrid=False,
            zeroline=False,
        ),
        yaxis=dict(
            showticklabels=False,
            showgrid=False,
            zeroline=False,
        ),
        width=width or 800,
        height=height,
        plot_bgcolor="white",
    )

    if show:  # pragma: no cover
        fig.show(renderer=renderer)
        return None
    else:
        return fig
985+
986+
@_check_types
@doc(
    summary="""
        Plot a chord diagram showing haplotype sharing between cohorts.
    """,
)
def plot_haplotype_sharing_chord(
    self,
    region: base_params.regions,
    cohort_col: hapclust_params.cohort_col,
    analysis: hap_params.analysis = base_params.DEFAULT,
    sample_sets: Optional[base_params.sample_sets] = None,
    sample_query: Optional[base_params.sample_query] = None,
    sample_query_options: Optional[base_params.sample_query_options] = None,
    cohort_size: Optional[base_params.cohort_size] = None,
    random_seed: base_params.random_seed = 42,
    title: plotly_params.title = True,
    width: plotly_params.fig_width = 600,
    height: plotly_params.fig_height = 600,
    show: plotly_params.show = True,
    renderer: plotly_params.renderer = None,
    chunks: base_params.chunks = base_params.native_chunks,
    inline_array: base_params.inline_array = base_params.inline_array_default,
) -> plotly_params.figure:
    # Cohorts are placed evenly around a unit circle; every pair with a
    # non-zero sharing count is joined by a quadratic Bezier curve bent
    # towards the centre, with line width scaled by the count. Returns the
    # figure when show is false, otherwise displays it and returns None.
    import plotly.graph_objects as go
    import plotly.express as px

    sharing_matrix, cohorts = self._compute_haplotype_sharing(
        region=region,
        analysis=analysis,
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        cohort_size=cohort_size,
        random_seed=random_seed,
        cohort_col=cohort_col,
        chunks=chunks,
        inline_array=inline_array,
    )

    n_cohorts = len(cohorts)
    cohort_list = list(cohorts)

    # endpoint=False so the last cohort does not land on top of the first.
    angles = np.linspace(0, 2 * np.pi, n_cohorts, endpoint=False)
    radius = 1.0
    cohort_x = {c: radius * np.cos(a) for c, a in zip(cohort_list, angles)}
    cohort_y = {c: radius * np.sin(a) for c, a in zip(cohort_list, angles)}

    # The palette is cycled if there are more cohorts than colours.
    palette = px.colors.qualitative.Dark24
    cohort_colors = {
        c: palette[i % len(palette)] for i, c in enumerate(cohort_list)
    }

    fig = go.Figure()

    # Normalisation constant for line widths. `initial=0` keeps the reduction
    # defined when there are no cohorts at all (empty matrix), and the floor
    # of 1 avoids dividing by zero when nothing is shared.
    max_sharing = max(sharing_matrix.values.max(initial=0), 1)

    # Quadratic Bezier weights are the same for every chord, so build them
    # once. The control point is the centre of the circle.
    t = np.linspace(0, 1, 50)
    w_start = (1 - t) ** 2
    w_ctrl = 2 * (1 - t) * t
    w_end = t**2
    cx, cy = 0, 0

    for i, c1 in enumerate(cohort_list):
        for c2 in cohort_list[i + 1 :]:
            count = sharing_matrix.loc[c1, c2]
            if count == 0:
                continue

            x1, y1 = cohort_x[c1], cohort_y[c1]
            x2, y2 = cohort_x[c2], cohort_y[c2]

            # Width ranges from just above 1 up to 10 for the largest count.
            line_width = 1 + (count / max_sharing) * 9

            chord_x = w_start * x1 + w_ctrl * cx + w_end * x2
            chord_y = w_start * y1 + w_ctrl * cy + w_end * y2

            fig.add_trace(
                go.Scatter(
                    x=chord_x,
                    y=chord_y,
                    mode="lines",
                    line=dict(width=line_width, color=cohort_colors[c1]),
                    opacity=0.6,
                    hoverinfo="text",
                    # Separate the two labels so the hover text is readable.
                    text=f"{c1} - {c2}: {count} shared",
                    showlegend=False,
                )
            )

    label_x = [cohort_x[c] for c in cohort_list]
    label_y = [cohort_y[c] for c in cohort_list]

    # Cohort markers; labels go above markers in the upper half of the
    # circle and below markers in the lower half.
    fig.add_trace(
        go.Scatter(
            x=label_x,
            y=label_y,
            mode="markers+text",
            marker=dict(size=14, color=[cohort_colors[c] for c in cohort_list]),
            text=cohort_list,
            textposition=[
                "top center" if y >= 0 else "bottom center" for y in label_y
            ],
            textfont=dict(size=10),
            hoverinfo="text",
            hovertext=cohort_list,
            showlegend=False,
        )
    )

    # title=True requests an automatically generated title; any other value
    # is passed through to plotly as given.
    if title is True:
        title_lines = []
        if sample_sets is not None:
            title_lines.append(f"Sample sets: {sample_sets}")
        if sample_query is not None:
            title_lines.append(f"Sample query: {sample_query}")
        title_lines.append(f"Genomic region: {region}")
        title_lines.append(f"Cohorts grouped by: {cohort_col}")
        title = "<br>".join(title_lines)

    fig.update_layout(
        title=title,
        xaxis=dict(
            showticklabels=False,
            showgrid=False,
            zeroline=False,
            # Lock the aspect ratio so the circle is not drawn as an ellipse.
            scaleanchor="y",
        ),
        yaxis=dict(
            showticklabels=False,
            showgrid=False,
            zeroline=False,
        ),
        width=width,
        height=height,
        plot_bgcolor="white",
    )

    if show:  # pragma: no cover
        fig.show(renderer=renderer)
        return None
    else:
        return fig
1129+
7801130

7811131
def _filter_and_remap(arr, x):
7821132
from collections import Counter

malariagen_data/anoph/hapclust_params.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,11 @@
1515
]
1616

1717
distance_metric_default: Literal["hamming", "dxy"] = "hamming"
18+
19+
cohort_col: TypeAlias = Annotated[
20+
str,
21+
"""
22+
Column name in sample metadata used to define cohorts for grouping,
23+
e.g., 'country', 'taxon', 'aim_species'.
24+
""",
25+
]

tests/anoph/test_hapclust.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,27 @@ def test_plot_haplotype_clustering(fixture, api: AnophelesHapClustAnalysis):
9696

9797
# Run checks.
9898
api.plot_haplotype_clustering(**hapclust_params)
99+
100+
101+
@parametrize_with_cases("fixture,api", cases=".")
def test_plot_haplotype_sharing_arc(fixture, api: AnophelesHapClustAnalysis):
    # Smoke test: the arc plot can be built for a random region and a single
    # randomly chosen sample set, and a figure is returned when show=False.
    sample_set_ids = api.sample_sets()["sample_set"].to_list()

    # Set up test parameters.
    arc_params = dict(
        region=fixture.random_region_str(region_size=5000),
        cohort_col="country",
        sample_sets=[random.choice(sample_set_ids)],
        show=False,
    )

    # Run checks.
    figure = api.plot_haplotype_sharing_arc(**arc_params)
    assert figure is not None
111+
112+
113+
@parametrize_with_cases("fixture,api", cases=".")
def test_plot_haplotype_sharing_chord(fixture, api: AnophelesHapClustAnalysis):
    # Smoke test: the chord plot can be built for a random region and a single
    # randomly chosen sample set, and a figure is returned when show=False.
    sample_set_ids = api.sample_sets()["sample_set"].to_list()

    # Set up test parameters.
    chord_params = dict(
        region=fixture.random_region_str(region_size=5000),
        cohort_col="country",
        sample_sets=[random.choice(sample_set_ids)],
        show=False,
    )

    # Run checks.
    figure = api.plot_haplotype_sharing_chord(**chord_params)
    assert figure is not None

0 commit comments

Comments
 (0)