Skip to content

Commit 3bad016

Browse files
committed
WIP: dev support for unrestricted_use_only, surveillance_use_only
1 parent a9f44c4 commit 3bad016

2 files changed

Lines changed: 41 additions & 7 deletions

File tree

malariagen_data/anoph/aim_data.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -172,12 +172,10 @@ def aim_calls(
172172
sample_query_options = sample_query_options or {}
173173

174174
# Determine which samples match the sample query.
175-
loc_samples = df_samples.eval(
176-
prepared_sample_query, **sample_query_options
177-
).values
175+
loc_samples = df_samples.eval(prepared_sample_query, **sample_query_options)
178176

179177
# Raise an error if no samples match the sample query.
180-
if np.count_nonzero(loc_samples) == 0:
178+
if not loc_samples.any():
181179
raise ValueError(
182180
f"No samples found for query {prepared_sample_query!r}"
183181
)

malariagen_data/anoph/cnv_data.py

Lines changed: 39 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -256,10 +256,10 @@ def cnv_hmm(
256256
# Determine which samples match the sample query.
257257
loc_samples = df_samples.eval(
258258
prepared_sample_query, **sample_query_options
259-
).values
259+
)
260260

261261
# Raise an error if no samples match the sample query.
262-
if np.count_nonzero(loc_samples) == 0:
262+
if not loc_samples.any():
263263
raise ValueError(
264264
f"No samples found for query {prepared_sample_query!r}"
265265
)
@@ -435,15 +435,19 @@ def cnv_coverage_calls(
435435

436436
debug("normalise parameters")
437437
regions: List[Region] = parse_multi_region(self, region)
438+
prepared_sample_set = self._prep_sample_sets_param(sample_sets=sample_set)[0]
439+
440+
# Delete original parameters to prevent accidental use.
438441
del region
442+
del sample_set
439443

440444
debug("access data and concatenate as needed")
441445
lx = []
442446
for r in regions:
443447
debug("obtain coverage calls for the contig")
444448
x = self._cnv_coverage_calls_dataset(
445449
contig=r.contig,
446-
sample_set=sample_set,
450+
sample_set=prepared_sample_set,
447451
analysis=analysis,
448452
inline_array=inline_array,
449453
chunks=chunks,
@@ -462,6 +466,38 @@ def cnv_coverage_calls(
462466
lx.append(x)
463467
ds = simple_xarray_concat(lx, dim=DIM_VARIANT)
464468

469+
# Filter the samples using this default sample query.
470+
# For example, this might filter out non-surveillance samples.
471+
prepared_sample_query = self._prep_sample_query_param(sample_query="")
472+
473+
# Get the relevant sample metadata.
474+
df_samples = self.sample_metadata(sample_sets=prepared_sample_set)
475+
476+
# Determine which samples match the sample query.
477+
if prepared_sample_query != "":
478+
loc_samples = df_samples.eval(prepared_sample_query)
479+
else:
480+
loc_samples = pd.Series(True, index=df_samples.index)
481+
482+
# Raise an error if no samples match the sample query.
483+
if not loc_samples.any():
484+
raise ValueError(f"No samples found for query {prepared_sample_query!r}")
485+
486+
# Get the relevant sample ids from the sample metadata DataFrame, using the boolean mask.
487+
relevant_sample_ids = df_samples.loc[loc_samples, "sample_id"].values
488+
489+
# Get all the sample ids from the unfiltered CNV coverage calls Dataset.
490+
ds_sample_ids = ds.coords["sample_id"].values
491+
492+
# Get the indices of samples in the CNV coverage calls Dataset that match the relevant sample ids.
493+
# Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
494+
relevant_sample_indices = np.where(np.isin(ds_sample_ids, relevant_sample_ids))[
495+
0
496+
]
497+
498+
# Select only the relevant samples from the CNV coverage calls Dataset.
499+
ds = ds.isel(samples=relevant_sample_indices)
500+
465501
return ds
466502

467503
@check_types

0 commit comments

Comments
 (0)