@@ -266,8 +266,16 @@ def _h12_gwss(
266266 # Compute window midpoints.
267267 pos = ds_haps ["variant_position" ].values
268268 x = allel .moving_statistic (pos , statistic = np .mean , size = window_size )
269+ contigs = np .asarray (
270+ allel .moving_statistic (
271+ ds_haps ["variant_contig" ].values ,
272+ statistic = np .median ,
273+ size = window_size ,
274+ ),
275+ dtype = int ,
276+ )
269277
270- results = dict (x = x , h12 = h12 )
278+ results = dict (x = x , h12 = h12 , contigs = contigs )
271279
272280 return results
273281
@@ -277,6 +285,7 @@ def _h12_gwss(
277285 returns = dict (
278286 x = "An array containing the window centre point genomic positions." ,
279287 h12 = "An array with h12 statistic values for each window." ,
288+ contigs = "An array with the contig for each window. The median is chosen for windows overlapping a change of contig." ,
280289 ),
281290 )
282291 def h12_gwss (
@@ -297,10 +306,10 @@ def h12_gwss(
297306 random_seed : base_params .random_seed = 42 ,
298307 chunks : base_params .chunks = base_params .native_chunks ,
299308 inline_array : base_params .inline_array = base_params .inline_array_default ,
300- ) -> Tuple [np .ndarray , np .ndarray ]:
309+ ) -> Tuple [np .ndarray , np .ndarray , np . ndarray ]:
301310 # Change this name if you ever change the behaviour of this function, to
302311 # invalidate any previously cached data.
303- name = "h12_gwss_v1 "
312+ name = "h12_gwss_contig_v1 "
304313
305314 params = dict (
306315 contig = contig ,
@@ -327,8 +336,9 @@ def h12_gwss(
327336
328337 x = results ["x" ]
329338 h12 = results ["h12" ]
339+ contigs = results ["contigs" ]
330340
331- return x , h12
341+ return x , h12 , contigs
332342
333343 @check_types
334344 @doc (
@@ -354,14 +364,15 @@ def plot_h12_gwss_track(
354364 sizing_mode : gplt_params .sizing_mode = gplt_params .sizing_mode_default ,
355365 width : gplt_params .width = gplt_params .width_default ,
356366 height : gplt_params .height = 200 ,
367+ contig_colors : gplt_params .contig_colors = gplt_params .contig_colors_default ,
357368 show : gplt_params .show = True ,
358369 x_range : Optional [gplt_params .x_range ] = None ,
359370 output_backend : gplt_params .output_backend = gplt_params .output_backend_default ,
360371 chunks : base_params .chunks = base_params .native_chunks ,
361372 inline_array : base_params .inline_array = base_params .inline_array_default ,
362373 ) -> gplt_params .figure :
363374 # Compute H12.
364- x , h12 = self .h12_gwss (
375+ x , h12 , contigs = self .h12_gwss (
365376 contig = contig ,
366377 analysis = analysis ,
367378 window_size = window_size ,
@@ -412,15 +423,14 @@ def plot_h12_gwss_track(
412423 )
413424
414425 # Plot H12.
415- fig .scatter (
416- x = x ,
417- y = h12 ,
418- marker = "circle" ,
419- size = 3 ,
420- line_width = 1 ,
421- line_color = "black" ,
422- fill_color = None ,
423- )
426+ for s in set (contigs ):
427+ idxs = contigs == s
428+ fig .scatter (
429+ x = x [idxs ],
430+ y = h12 [idxs ],
431+ marker = "circle" ,
432+ color = contig_colors [s % len (contig_colors )],
433+ )
424434
425435 # Tidy up the plot.
426436 fig .yaxis .axis_label = "H12"
@@ -457,6 +467,7 @@ def plot_h12_gwss(
457467 sizing_mode : gplt_params .sizing_mode = gplt_params .sizing_mode_default ,
458468 width : gplt_params .width = gplt_params .width_default ,
459469 track_height : gplt_params .track_height = 170 ,
470+ contig_colors : gplt_params .contig_colors = gplt_params .contig_colors_default ,
460471 genes_height : gplt_params .genes_height = gplt_params .genes_height_default ,
461472 show : gplt_params .show = True ,
462473 output_backend : gplt_params .output_backend = gplt_params .output_backend_default ,
@@ -479,6 +490,7 @@ def plot_h12_gwss(
479490 sizing_mode = sizing_mode ,
480491 width = width ,
481492 height = track_height ,
493+ contig_colors = contig_colors ,
482494 show = False ,
483495 output_backend = output_backend ,
484496 chunks = chunks ,
@@ -575,7 +587,7 @@ def plot_h12_gwss_multi_overlay_track(
575587 )
576588
577589 # Determine X axis range.
578- x , _ = res [list (cohort_queries .keys ())[0 ]]
590+ x , _ , _ = res [list (cohort_queries .keys ())[0 ]]
579591 x_min = x [0 ]
580592 x_max = x [- 1 ]
581593 if x_range is None :
@@ -610,7 +622,7 @@ def plot_h12_gwss_multi_overlay_track(
610622 )
611623
612624 # Plot H12.
613- for i , (cohort_label , (x , h12 )) in enumerate (res .items ()):
625+ for i , (cohort_label , (x , h12 , contig )) in enumerate (res .items ()):
614626 fig .scatter (
615627 x = x ,
616628 y = h12 ,
0 commit comments