@@ -461,24 +461,26 @@ def snp_genotypes(
461461 sample_query = sample_query , sample_indices = sample_indices
462462 )
463463
464- # Normalise parameters.
465- sample_sets_prepped = self ._prep_sample_sets_param (sample_sets = sample_sets )
464+ # Prepare parameters.
465+ prepared_sample_sets = self ._prep_sample_sets_param (sample_sets = sample_sets )
466+ prepared_sample_query = self ._prep_sample_query_param (sample_query = sample_query )
467+ prepared_regions : List [Region ] = parse_multi_region (self , region )
468+ prepared_site_mask = self ._prep_optional_site_mask_param (site_mask = site_mask )
469+
470+ # Delete original parameters to prevent accidental use.
466471 del sample_sets
467- sample_query_prepped = self ._prep_sample_query_param (sample_query = sample_query )
468472 del sample_query
469- regions : List [Region ] = parse_multi_region (self , region )
470473 del region
471- site_mask_prepped = self ._prep_optional_site_mask_param (site_mask = site_mask )
472474 del site_mask
473475
474476 with self ._spinner ("Access SNP genotypes" ):
475477 # Concatenate multiple sample sets and/or contigs.
476478 lx = []
477- for r in regions :
479+ for r in prepared_regions :
478480 contig = r .contig
479481 ly = []
480482
481- for s in sample_sets_prepped :
483+ for s in prepared_sample_sets :
482484 y = self ._snp_genotypes_for_contig (
483485 contig = contig ,
484486 sample_set = s ,
@@ -508,24 +510,53 @@ def snp_genotypes(
508510 d = da_concat (lx , axis = 0 )
509511
510512 # Apply site filters if requested.
511- if site_mask_prepped is not None :
513+ if prepared_site_mask is not None :
512514 loc_sites = self .site_filters (
513- region = regions ,
514- mask = site_mask_prepped ,
515+ region = prepared_regions ,
516+ mask = prepared_site_mask ,
515517 )
516518 d = da_compress (loc_sites , d , axis = 0 )
517519
518- # Apply sample selection if requested.
519- if sample_query_prepped is not None :
520- df_samples = self .sample_metadata (sample_sets = sample_sets_prepped )
521- sample_query_options = sample_query_options or {}
522- loc_samples = df_samples .eval (
523- sample_query_prepped , ** sample_query_options
524- ).values
525- if np .count_nonzero (loc_samples ) == 0 :
526- raise ValueError (f"No samples found for query { sample_query_prepped !r} " )
520+ # Apply the sample_query, if there is one.
521+ # Note: this might have been internally modified, e.g. `is_surveillance == True`.
522+ if prepared_sample_query is not None :
523+ # Note: the unfiltered Dask array `d` is not aligned with the filtered `sample_metadata`,
524+ # so we cannot use filtered `sample_metadata` to get the relevant boolean filter.
525+
526+ # Note: the unfiltered Dask array `d` does not contain sample identifiers,
527+ # so we cannot use a list of relevant sample ids to produce the boolean filter directly.
528+
529+ # Note: we can first determine the list of relevant sample ids using filtered `sample_metadata`,
530+ # then use the unfiltered `general_metadata` to determine the appropriate boolean filter.
531+
532+ df_filtered_samples = self .sample_metadata (
533+ sample_sets = prepared_sample_sets ,
534+ sample_query = prepared_sample_query ,
535+ sample_query_options = sample_query_options ,
536+ )
537+
538+ # Raise an error if no samples match the sample query.
539+ if len (df_filtered_samples ) == 0 :
540+ raise ValueError (
541+ f"No samples found for query { prepared_sample_query !r} "
542+ )
543+
544+ # Get the list of unfiltered samples, in order to produce an aligned boolean filter.
545+ df_unfiltered_samples = self .general_metadata (
546+ sample_sets = prepared_sample_sets
547+ )
548+
549+ # Get a boolean array for unfiltered data, indicating which samples match the query.
550+ loc_samples = df_unfiltered_samples ["sample_id" ].isin (
551+ df_filtered_samples ["sample_id" ]
552+ )
553+
554+ # Filter the Dask array using the boolean array.
527555 d = da .compress (loc_samples , d , axis = 1 )
528- elif sample_indices is not None :
556+
557+ # Apply the sample_indices, if there are any.
558+ # Note: this might need to apply to the result of an internal sample_query, e.g. `is_surveillance == True`.
559+ if sample_indices is not None :
529560 d = da .take (d , sample_indices , axis = 1 )
530561
531562 return d
0 commit comments