@@ -265,8 +265,16 @@ def _h12_gwss(
265265 # Compute window midpoints.
266266 pos = ds_haps ["variant_position" ].values
267267 x = allel .moving_statistic (pos , statistic = np .mean , size = window_size )
268+ contigs = np .asarray (
269+ allel .moving_statistic (
270+ ds_haps ["variant_contig" ].values ,
271+ statistic = np .median ,
272+ size = window_size ,
273+ ),
274+ dtype = int ,
275+ )
268276
269- results = dict (x = x , h12 = h12 )
277+ results = dict (x = x , h12 = h12 , contigs = contigs )
270278
271279 return results
272280
@@ -276,6 +284,7 @@ def _h12_gwss(
276284 returns = dict (
277285 x = "An array containing the window centre point genomic positions." ,
278286 h12 = "An array with h12 statistic values for each window." ,
287+ contigs = "An array with the contig for each window. The median is chosen for windows overlapping a change of contig." ,
279288 ),
280289 )
281290 def h12_gwss (
@@ -296,10 +305,10 @@ def h12_gwss(
296305 random_seed : base_params .random_seed = 42 ,
297306 chunks : base_params .chunks = base_params .native_chunks ,
298307 inline_array : base_params .inline_array = base_params .inline_array_default ,
299- ) -> Tuple [np .ndarray , np .ndarray ]:
308+ ) -> Tuple [np .ndarray , np .ndarray , np . ndarray ]:
300309 # Change this name if you ever change the behaviour of this function, to
301310 # invalidate any previously cached data.
302- name = "h12_gwss_v1 "
311+ name = "h12_gwss_v2 "
303312
304313 params = dict (
305314 contig = contig ,
@@ -326,8 +335,9 @@ def h12_gwss(
326335
327336 x = results ["x" ]
328337 h12 = results ["h12" ]
338+ contigs = results ["contigs" ]
329339
330- return x , h12
340+ return x , h12 , contigs
331341
332342 @check_types
333343 @doc (
@@ -353,14 +363,15 @@ def plot_h12_gwss_track(
353363 sizing_mode : gplt_params .sizing_mode = gplt_params .sizing_mode_default ,
354364 width : gplt_params .width = gplt_params .width_default ,
355365 height : gplt_params .height = 200 ,
366+ contig_colors : gplt_params .contig_colors = gplt_params .contig_colors_default ,
356367 show : gplt_params .show = True ,
357368 x_range : Optional [gplt_params .x_range ] = None ,
358369 output_backend : gplt_params .output_backend = gplt_params .output_backend_default ,
359370 chunks : base_params .chunks = base_params .native_chunks ,
360371 inline_array : base_params .inline_array = base_params .inline_array_default ,
361372 ) -> gplt_params .figure :
362373 # Compute H12.
363- x , h12 = self .h12_gwss (
374+ x , h12 , contigs = self .h12_gwss (
364375 contig = contig ,
365376 analysis = analysis ,
366377 window_size = window_size ,
@@ -411,15 +422,14 @@ def plot_h12_gwss_track(
411422 )
412423
413424 # Plot H12.
414- fig .scatter (
415- x = x ,
416- y = h12 ,
417- marker = "circle" ,
418- size = 3 ,
419- line_width = 1 ,
420- line_color = "black" ,
421- fill_color = None ,
422- )
425+ for s in set (contigs ):
426+ idxs = contigs == s
427+ fig .scatter (
428+ x = x [idxs ],
429+ y = h12 [idxs ],
430+ marker = "circle" ,
431+ color = contig_colors [s % len (contig_colors )],
432+ )
423433
424434 # Tidy up the plot.
425435 fig .yaxis .axis_label = "H12"
@@ -456,6 +466,7 @@ def plot_h12_gwss(
456466 sizing_mode : gplt_params .sizing_mode = gplt_params .sizing_mode_default ,
457467 width : gplt_params .width = gplt_params .width_default ,
458468 track_height : gplt_params .track_height = 170 ,
469+ contig_colors : gplt_params .contig_colors = gplt_params .contig_colors_default ,
459470 genes_height : gplt_params .genes_height = gplt_params .genes_height_default ,
460471 show : gplt_params .show = True ,
461472 output_backend : gplt_params .output_backend = gplt_params .output_backend_default ,
@@ -478,6 +489,7 @@ def plot_h12_gwss(
478489 sizing_mode = sizing_mode ,
479490 width = width ,
480491 height = track_height ,
492+ contig_colors = contig_colors ,
481493 show = False ,
482494 output_backend = output_backend ,
483495 chunks = chunks ,
@@ -574,7 +586,7 @@ def plot_h12_gwss_multi_overlay_track(
574586 )
575587
576588 # Determine X axis range.
577- x , _ = res [list (cohort_queries .keys ())[0 ]]
589+ x , _ , _ = res [list (cohort_queries .keys ())[0 ]]
578590 x_min = x [0 ]
579591 x_max = x [- 1 ]
580592 if x_range is None :
@@ -609,7 +621,7 @@ def plot_h12_gwss_multi_overlay_track(
609621 )
610622
611623 # Plot H12.
612- for i , (cohort_label , (x , h12 )) in enumerate (res .items ()):
624+ for i , (cohort_label , (x , h12 , contig )) in enumerate (res .items ()):
613625 fig .scatter (
614626 x = x ,
615627 y = h12 ,
0 commit comments