@@ -256,10 +256,10 @@ def cnv_hmm(
256256 # Determine which samples match the sample query.
257257 loc_samples = df_samples .eval (
258258 prepared_sample_query , ** sample_query_options
259- ). values
259+ )
260260
261261 # Raise an error if no samples match the sample query.
262- if np . count_nonzero ( loc_samples ) == 0 :
262+ if not loc_samples . any () :
263263 raise ValueError (
264264 f"No samples found for query { prepared_sample_query !r} "
265265 )
@@ -435,15 +435,19 @@ def cnv_coverage_calls(
435435
436436 debug ("normalise parameters" )
437437 regions : List [Region ] = parse_multi_region (self , region )
438+ prepared_sample_set = self ._prep_sample_sets_param (sample_sets = sample_set )[0 ]
439+
440+ # Delete original parameters to prevent accidental use.
438441 del region
442+ del sample_set
439443
440444 debug ("access data and concatenate as needed" )
441445 lx = []
442446 for r in regions :
443447 debug ("obtain coverage calls for the contig" )
444448 x = self ._cnv_coverage_calls_dataset (
445449 contig = r .contig ,
446- sample_set = sample_set ,
450+ sample_set = prepared_sample_set ,
447451 analysis = analysis ,
448452 inline_array = inline_array ,
449453 chunks = chunks ,
@@ -462,6 +466,38 @@ def cnv_coverage_calls(
462466 lx .append (x )
463467 ds = simple_xarray_concat (lx , dim = DIM_VARIANT )
464468
469+ # Filter the samples using this default sample query.
470+ # For example, this might filter out non-surveillance samples.
471+ prepared_sample_query = self ._prep_sample_query_param (sample_query = "" )
472+
473+ # Get the relevant sample metadata.
474+ df_samples = self .sample_metadata (sample_sets = prepared_sample_set )
475+
476+ # Determine which samples match the sample query.
477+ if prepared_sample_query != "" :
478+ loc_samples = df_samples .eval (prepared_sample_query )
479+ else :
480+ loc_samples = pd .Series (True , index = df_samples .index )
481+
482+ # Raise an error if no samples match the sample query.
483+ if not loc_samples .any ():
484+ raise ValueError (f"No samples found for query { prepared_sample_query !r} " )
485+
486+ # Get the relevant sample ids from the sample metadata DataFrame, using the boolean mask.
487+ relevant_sample_ids = df_samples .loc [loc_samples , "sample_id" ].values
488+
489+ # Get all the sample ids from the unfiltered CNV coverage calls Dataset.
490+ ds_sample_ids = ds .coords ["sample_id" ].values
491+
492+ # Get the indices of samples in the CNV coverage calls Dataset that match the relevant sample ids.
493+ # Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
494+ relevant_sample_indices = np .where (np .isin (ds_sample_ids , relevant_sample_ids ))[
495+ 0
496+ ]
497+
498+ # Select only the relevant samples from the CNV coverage calls Dataset.
499+ ds = ds .isel (samples = relevant_sample_indices )
500+
465501 return ds
466502
467503 @check_types
0 commit comments