Skip to content

Commit 88347b0

Browse files
committed
WIP: update cnv_discordant_read_calls to honour constructor params
1 parent 3bad016 commit 88347b0

1 file changed

Lines changed: 37 additions & 21 deletions

File tree

malariagen_data/anoph/cnv_data.py

Lines changed: 37 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,10 @@ def cnv_hmm(
200200
prepared_sample_sets = self._prep_sample_sets_param(sample_sets=sample_sets)
201201
prepared_sample_query = self._prep_sample_query_param(sample_query=sample_query)
202202
regions: List[Region] = parse_multi_region(self, region)
203+
204+
# Delete original parameters to prevent accidental use.
205+
del sample_sets
206+
del sample_query
203207
del region
204208

205209
with self._spinner("Access CNV HMM data"):
@@ -244,7 +248,6 @@ def cnv_hmm(
244248
ds = simple_xarray_concat(lx, dim=DIM_VARIANT)
245249

246250
debug("handle sample query")
247-
248251
# If there's a sample query...
249252
if prepared_sample_query is not None:
250253
# Get the relevant sample metadata.
@@ -640,16 +643,20 @@ def cnv_discordant_read_calls(
640643
# CNV alleles have unknown start or end coordinates.
641644

642645
debug("normalise parameters")
643-
sample_sets = self._prep_sample_sets_param(sample_sets=sample_sets)
644-
sample_query = self._prep_sample_query_param(sample_query=sample_query)
646+
prepared_sample_sets = self._prep_sample_sets_param(sample_sets=sample_sets)
647+
prepared_sample_query = self._prep_sample_query_param(sample_query=sample_query)
645648
if isinstance(contig, str):
646649
contig = [contig]
647650

651+
# Delete original parameters to prevent accidental use.
652+
del sample_sets
653+
del sample_query
654+
648655
debug("access data and concatenate as needed")
649656
lx = []
650657
for c in contig:
651658
ly = []
652-
for s in sample_sets:
659+
for s in prepared_sample_sets:
653660
y = self._cnv_discordant_read_calls_dataset(
654661
contig=c,
655662
sample_set=s,
@@ -673,30 +680,39 @@ def cnv_discordant_read_calls(
673680
ds = simple_xarray_concat(lx, dim=DIM_VARIANT)
674681

675682
debug("handle sample query")
676-
if sample_query is not None:
683+
684+
# If there's a sample query...
685+
if prepared_sample_query is not None:
677686
debug("load sample metadata")
678-
df_samples = self.sample_metadata(sample_sets=sample_sets)
687+
# Get the relevant sample metadata.
688+
df_samples = self.sample_metadata(sample_sets=prepared_sample_sets)
689+
690+
# If there are no sample query options, then default to an empty dict.
691+
sample_query_options = sample_query_options or {}
692+
693+
# Determine which samples match the sample query.
694+
loc_samples = df_samples.eval(prepared_sample_query, **sample_query_options)
679695

680-
if df_samples.empty:
696+
# Raise an error if no samples match the sample query.
697+
if not loc_samples.any():
681698
raise ValueError(
682-
f"No samples found for sample sets {sample_sets!r}. These samples might be unavailable or irrelevant with respect to settings."
699+
f"No samples found for query {prepared_sample_query!r}"
683700
)
684701

685-
debug("align sample metadata with CNV data")
686-
cnv_samples = ds["sample_id"].values.tolist()
687-
df_samples_cnv = (
688-
df_samples.set_index("sample_id").loc[cnv_samples].reset_index()
689-
)
702+
# Get the relevant sample ids from the sample metadata DataFrame, using the boolean mask.
703+
relevant_sample_ids = df_samples.loc[loc_samples, "sample_id"].values
690704

691-
debug("apply the query")
692-
sample_query_options = sample_query_options or {}
693-
loc_query_samples = df_samples_cnv.eval(
694-
sample_query, **sample_query_options
695-
).values
696-
if np.count_nonzero(loc_query_samples) == 0:
697-
raise ValueError(f"No samples found for query {sample_query!r}")
705+
# Get all the sample ids from the unfiltered CNV discordant reads Dataset.
706+
ds_sample_ids = ds.coords["sample_id"].values
707+
708+
# Get the indices of samples in the CNV discordant reads Dataset that match the relevant sample ids.
709+
# Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
710+
relevant_sample_indices = np.where(
711+
np.isin(ds_sample_ids, relevant_sample_ids)
712+
)[0]
698713

699-
ds = ds.isel(samples=loc_query_samples)
714+
# Select only the relevant samples from the CNV discordant reads Dataset.
715+
ds = ds.isel(samples=relevant_sample_indices)
700716

701717
return ds
702718

0 commit comments

Comments
 (0)