Skip to content

Commit 4b9904f

Browse files
committed
Amend snp_genotypes to handle sample_indices when surveillance_use_only
1 parent 6c4e74f commit 4b9904f

1 file changed

Lines changed: 51 additions & 20 deletions

File tree

malariagen_data/anoph/snp_data.py

Lines changed: 51 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -461,24 +461,26 @@ def snp_genotypes(
461461
sample_query=sample_query, sample_indices=sample_indices
462462
)
463463

464-
# Normalise parameters.
465-
sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets)
464+
# Prepare parameters.
465+
prepared_sample_sets = self._prep_sample_sets_param(sample_sets=sample_sets)
466+
prepared_sample_query = self._prep_sample_query_param(sample_query=sample_query)
467+
prepared_regions: List[Region] = parse_multi_region(self, region)
468+
prepared_site_mask = self._prep_optional_site_mask_param(site_mask=site_mask)
469+
470+
# Delete original parameters to prevent accidental use.
466471
del sample_sets
467-
sample_query_prepped = self._prep_sample_query_param(sample_query=sample_query)
468472
del sample_query
469-
regions: List[Region] = parse_multi_region(self, region)
470473
del region
471-
site_mask_prepped = self._prep_optional_site_mask_param(site_mask=site_mask)
472474
del site_mask
473475

474476
with self._spinner("Access SNP genotypes"):
475477
# Concatenate multiple sample sets and/or contigs.
476478
lx = []
477-
for r in regions:
479+
for r in prepared_regions:
478480
contig = r.contig
479481
ly = []
480482

481-
for s in sample_sets_prepped:
483+
for s in prepared_sample_sets:
482484
y = self._snp_genotypes_for_contig(
483485
contig=contig,
484486
sample_set=s,
@@ -508,24 +510,53 @@ def snp_genotypes(
508510
d = da_concat(lx, axis=0)
509511

510512
# Apply site filters if requested.
511-
if site_mask_prepped is not None:
513+
if prepared_site_mask is not None:
512514
loc_sites = self.site_filters(
513-
region=regions,
514-
mask=site_mask_prepped,
515+
region=prepared_regions,
516+
mask=prepared_site_mask,
515517
)
516518
d = da_compress(loc_sites, d, axis=0)
517519

518-
# Apply sample selection if requested.
519-
if sample_query_prepped is not None:
520-
df_samples = self.sample_metadata(sample_sets=sample_sets_prepped)
521-
sample_query_options = sample_query_options or {}
522-
loc_samples = df_samples.eval(
523-
sample_query_prepped, **sample_query_options
524-
).values
525-
if np.count_nonzero(loc_samples) == 0:
526-
raise ValueError(f"No samples found for query {sample_query_prepped!r}")
520+
# Apply the sample_query, if there is one.
521+
# Note: this might have been internally modified, e.g. `is_surveillance == True`.
522+
if prepared_sample_query is not None:
523+
# Note: the unfiltered Dask array `d` is not aligned with the filtered `sample_metadata`,
524+
# so we cannot use filtered `sample_metadata` to get the relevant boolean filter.
525+
526+
# Note: the unfiltered Dask array `d` does not contain sample identifiers,
527+
# so we cannot use a list of relevant sample ids to produce the boolean filter directly.
528+
529+
# Note: we can first determine the list of relevant sample ids using filtered `sample_metadata`,
530+
# then use the unfiltered `general_metadata` to determine the appropriate boolean filter.
531+
532+
df_filtered_samples = self.sample_metadata(
533+
sample_sets=prepared_sample_sets,
534+
sample_query=prepared_sample_query,
535+
sample_query_options=sample_query_options,
536+
)
537+
538+
# Raise an error if no samples match the sample query.
539+
if len(df_filtered_samples) == 0:
540+
raise ValueError(
541+
f"No samples found for query {prepared_sample_query!r}"
542+
)
543+
544+
# Get the list of unfiltered samples, in order to produce an aligned boolean filter.
545+
df_unfiltered_samples = self.general_metadata(
546+
sample_sets=prepared_sample_sets
547+
)
548+
549+
# Get a boolean array for unfiltered data, indicating which samples match the query.
550+
loc_samples = df_unfiltered_samples["sample_id"].isin(
551+
df_filtered_samples["sample_id"]
552+
)
553+
554+
# Filter the Dask array using the boolean array.
527555
d = da.compress(loc_samples, d, axis=1)
528-
elif sample_indices is not None:
556+
557+
# Apply the sample_indices, if there are any.
558+
# Note: this might need to apply to the result of an internal sample_query, e.g. `is_surveillance == True`.
559+
if sample_indices is not None:
529560
d = da.take(d, sample_indices, axis=1)
530561

531562
return d

0 commit comments

Comments
 (0)