Skip to content

Commit 65fd83c

Browse files
committed
WIP: dev support for surveillance_use_only, unrestricted_use_only params
1 parent e16eabb commit 65fd83c

7 files changed

Lines changed: 167 additions & 149 deletions

File tree

malariagen_data/anoph/aim_data.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -171,29 +171,12 @@ def aim_calls(
171171
# If there are no sample query options, then default to an empty dict.
172172
sample_query_options = sample_query_options or {}
173173

174-
# Determine which samples match the sample query.
175-
loc_samples = df_samples.eval(prepared_sample_query, **sample_query_options)
176-
177-
# Raise an error if no samples match the sample query.
178-
if not loc_samples.any():
179-
raise ValueError(
180-
f"No samples found for query {prepared_sample_query!r}"
181-
)
182-
183-
# Get the relevant sample ids from the sample metadata DataFrame, using the boolean mask.
184-
relevant_sample_ids = df_samples.loc[loc_samples, "sample_id"].values
185-
186-
# Get all the sample ids from the unfiltered AIM calls Dataset.
187-
ds_sample_ids = ds.coords["sample_id"].values
188-
189-
# Get the indices of samples in the AIM calls Dataset that match the relevant sample ids.
190-
# Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
191-
relevant_sample_indices = np.where(
192-
np.isin(ds_sample_ids, relevant_sample_ids)
193-
)[0]
194-
195-
# Select only the relevant samples from the AIM calls Dataset.
196-
ds = ds.isel(samples=relevant_sample_indices)
174+
ds = self._filter_sample_dataset(
175+
ds=ds,
176+
df_samples=df_samples,
177+
sample_query=prepared_sample_query,
178+
sample_query_options=sample_query_options,
179+
)
197180

198181
return ds
199182

malariagen_data/anoph/base.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tqdm.auto import tqdm as tqdm_auto # type: ignore
2929
from tqdm.dask import TqdmCallback # type: ignore
3030
from yaspin import yaspin # type: ignore
31+
import xarray as xr
3132

3233
from ..util import (
3334
CacheMiss,
@@ -933,6 +934,45 @@ def _prep_sample_query_param(
933934

934935
return prepped_sample_query
935936

937+
def _filter_sample_dataset(
938+
self,
939+
*,
940+
ds: xr.Dataset,
941+
df_samples: pd.DataFrame,
942+
sample_query: str,
943+
sample_query_options: dict,
944+
) -> xr.Dataset:
945+
"""Filters the given Dataset using the given DataFrame and query."""
946+
947+
# Note: "prepare" the params before calling this function.
948+
949+
# Determine which samples match the sample query.
950+
if sample_query != "":
951+
loc_samples = df_samples.eval(sample_query, **sample_query_options)
952+
else:
953+
loc_samples = pd.Series(True, index=df_samples.index)
954+
955+
# Raise an error if no samples match the sample query.
956+
if not loc_samples.any():
957+
raise ValueError(f"No samples found for query {sample_query!r}")
958+
959+
# Get the relevant sample ids from the sample metadata DataFrame, using the boolean mask.
960+
relevant_sample_ids = df_samples.loc[loc_samples, "sample_id"].values
961+
962+
# Get all the sample ids from the unfiltered Dataset.
963+
ds_sample_ids = ds.coords["sample_id"].values
964+
965+
# Get the indices of samples in the Dataset that match the relevant sample ids.
966+
# Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
967+
relevant_sample_indices = np.where(np.isin(ds_sample_ids, relevant_sample_ids))[
968+
0
969+
]
970+
971+
# Select only the relevant samples from the Dataset.
972+
ds = ds.isel(samples=relevant_sample_indices)
973+
974+
return ds
975+
936976
def _results_cache_add_analysis_params(self, params: dict):
937977
# Expect sub-classes will override to add any analysis parameters.
938978
pass

malariagen_data/anoph/cnv_data.py

Lines changed: 19 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -256,32 +256,13 @@ def cnv_hmm(
256256
# If there are no sample query options, then default to an empty dict.
257257
sample_query_options = sample_query_options or {}
258258

259-
# Determine which samples match the sample query.
260-
loc_samples = df_samples.eval(
261-
prepared_sample_query, **sample_query_options
259+
ds = self._filter_sample_dataset(
260+
ds=ds,
261+
df_samples=df_samples,
262+
sample_query=prepared_sample_query,
263+
sample_query_options=sample_query_options,
262264
)
263265

264-
# Raise an error if no samples match the sample query.
265-
if not loc_samples.any():
266-
raise ValueError(
267-
f"No samples found for query {prepared_sample_query!r}"
268-
)
269-
270-
# Get the relevant sample ids from the sample metadata DataFrame, using the boolean mask.
271-
relevant_sample_ids = df_samples.loc[loc_samples, "sample_id"].values
272-
273-
# Get all the sample ids from the unfiltered CNV HMM Dataset.
274-
ds_sample_ids = ds.coords["sample_id"].values
275-
276-
# Get the indices of samples in the CNV HMM Dataset that match the relevant sample ids.
277-
# Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
278-
relevant_sample_indices = np.where(
279-
np.isin(ds_sample_ids, relevant_sample_ids)
280-
)[0]
281-
282-
# Select only the relevant samples from the CNV HMM Dataset.
283-
ds = ds.isel(samples=relevant_sample_indices)
284-
285266
debug("handle coverage variance filter")
286267
if max_coverage_variance is not None:
287268
cov_var = ds["sample_coverage_variance"].values
@@ -476,30 +457,15 @@ def cnv_coverage_calls(
476457
# Get the relevant sample metadata.
477458
df_samples = self.sample_metadata(sample_sets=prepared_sample_set)
478459

479-
# Determine which samples match the sample query.
480-
if prepared_sample_query != "":
481-
loc_samples = df_samples.eval(prepared_sample_query)
482-
else:
483-
loc_samples = pd.Series(True, index=df_samples.index)
484-
485-
# Raise an error if no samples match the sample query.
486-
if not loc_samples.any():
487-
raise ValueError(f"No samples found for query {prepared_sample_query!r}")
488-
489-
# Get the relevant sample ids from the sample metadata DataFrame, using the boolean mask.
490-
relevant_sample_ids = df_samples.loc[loc_samples, "sample_id"].values
460+
# If there is no sample query, then default to an empty str.
461+
prepared_sample_query = prepared_sample_query or ""
491462

492-
# Get all the sample ids from the unfiltered CNV coverage calls Dataset.
493-
ds_sample_ids = ds.coords["sample_id"].values
494-
495-
# Get the indices of samples in the CNV coverage calls Dataset that match the relevant sample ids.
496-
# Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
497-
relevant_sample_indices = np.where(np.isin(ds_sample_ids, relevant_sample_ids))[
498-
0
499-
]
500-
501-
# Select only the relevant samples from the CNV coverage calls Dataset.
502-
ds = ds.isel(samples=relevant_sample_indices)
463+
ds = self._filter_sample_dataset(
464+
ds=ds,
465+
df_samples=df_samples,
466+
sample_query=prepared_sample_query,
467+
sample_query_options={},
468+
)
503469

504470
return ds
505471

@@ -690,29 +656,12 @@ def cnv_discordant_read_calls(
690656
# If there are no sample query options, then default to an empty dict.
691657
sample_query_options = sample_query_options or {}
692658

693-
# Determine which samples match the sample query.
694-
loc_samples = df_samples.eval(prepared_sample_query, **sample_query_options)
695-
696-
# Raise an error if no samples match the sample query.
697-
if not loc_samples.any():
698-
raise ValueError(
699-
f"No samples found for query {prepared_sample_query!r}"
700-
)
701-
702-
# Get the relevant sample ids from the sample metadata DataFrame, using the boolean mask.
703-
relevant_sample_ids = df_samples.loc[loc_samples, "sample_id"].values
704-
705-
# Get all the sample ids from the unfiltered CNV discordant reads Dataset.
706-
ds_sample_ids = ds.coords["sample_id"].values
707-
708-
# Get the indices of samples in the CNV discordant reads Dataset that match the relevant sample ids.
709-
# Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
710-
relevant_sample_indices = np.where(
711-
np.isin(ds_sample_ids, relevant_sample_ids)
712-
)[0]
713-
714-
# Select only the relevant samples from the CNV discordant reads Dataset.
715-
ds = ds.isel(samples=relevant_sample_indices)
659+
ds = self._filter_sample_dataset(
660+
ds=ds,
661+
df_samples=df_samples,
662+
sample_query=prepared_sample_query,
663+
sample_query_options=sample_query_options,
664+
)
716665

717666
return ds
718667

malariagen_data/anoph/distance.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ def biallelic_diplotype_pairwise_distances(
115115
# invalidate any previously cached data.
116116
name = "biallelic_diplotype_pairwise_distances"
117117

118-
# Normalize params for consistent hash value.
118+
## Normalize params for consistent hash value.
119+
120+
# Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
121+
# So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
119122
(
120123
sample_sets_prepped,
121124
sample_indices_prepped,
@@ -269,7 +272,10 @@ def njt(
269272
# invalidate any previously cached data.
270273
name = "njt_v1"
271274

272-
# Normalize params for consistent hash value.
275+
## Normalize params for consistent hash value.
276+
277+
# Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
278+
# So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
273279
(
274280
sample_sets_prepped,
275281
sample_indices_prepped,

malariagen_data/anoph/pca.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -80,27 +80,39 @@ def pca(
8080
) -> Tuple[pca_params.df_pca, pca_params.evr]:
8181
# Change this name if you ever change the behaviour of this function, to
8282
# invalidate any previously cached data.
83-
name = "pca_v4"
83+
name = "pca_v5"
8484

85-
# Normalize params for consistent hash value.
85+
## Normalize params for consistent hash value.
86+
87+
# Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
88+
# So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
8689
(
87-
sample_sets_prepped,
88-
sample_indices_prepped,
90+
prepared_sample_sets,
91+
prepared_sample_indices,
8992
) = self._prep_sample_selection_cache_params(
9093
sample_sets=sample_sets,
9194
sample_query=sample_query,
9295
sample_query_options=sample_query_options,
9396
sample_indices=sample_indices,
9497
)
95-
region_prepped = self._prep_region_cache_param(region=region)
96-
site_mask_prepped = self._prep_optional_site_mask_param(site_mask=site_mask)
98+
prepared_region = self._prep_region_cache_param(region=region)
99+
prepared_site_mask = self._prep_optional_site_mask_param(site_mask=site_mask)
100+
101+
# Delete original parameters to prevent accidental use.
102+
del sample_sets
103+
del sample_indices
104+
del sample_query
105+
del sample_query_options
106+
del region
107+
del site_mask
108+
97109
params = dict(
98-
region=region_prepped,
110+
region=prepared_region,
99111
n_snps=n_snps,
100112
thin_offset=thin_offset,
101-
sample_sets=sample_sets_prepped,
102-
sample_indices=sample_indices_prepped,
103-
site_mask=site_mask_prepped,
113+
sample_sets=prepared_sample_sets,
114+
sample_indices=prepared_sample_indices,
115+
site_mask=prepared_site_mask,
104116
site_class=site_class,
105117
min_minor_ac=min_minor_ac,
106118
max_missing_an=max_missing_an,
@@ -127,22 +139,18 @@ def pca(
127139
samples = results["samples"]
128140
loc_keep_fit = results["loc_keep_fit"]
129141

130-
# Load sample metadata.
131-
df_samples = self.sample_metadata(
132-
sample_sets=sample_sets,
133-
)
142+
# Create a new DataFrame containing the PCA coords data.
143+
df_pca = pd.DataFrame(coords, index=samples)
134144

135-
# Ensure aligned with genotype data.
136-
df_samples = df_samples.set_index("sample_id").loc[samples].reset_index()
145+
# Name the DataFrame's columns PC1, PC2, etc.
146+
df_pca.columns = pd.Index([f"PC{i+1}" for i in range(coords.shape[1])])
137147

138-
# Combine coords and sample metadata.
139-
df_coords = pd.DataFrame(
140-
{f"PC{i + 1}": coords[:, i] for i in range(coords.shape[1])}
141-
)
142-
df_pca = df_samples.join(df_coords, how="inner")
143-
# Add a column for which samples were included in fitting.
148+
# Add a column to indicate which samples were included in fitting.
144149
df_pca["pca_fit"] = loc_keep_fit
145150

151+
# Name the index.
152+
df_pca.index.name = "sample_id"
153+
146154
return df_pca, evr
147155

148156
def _pca(

malariagen_data/anoph/sample_metadata.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,20 +1054,26 @@ def _prep_sample_selection_cache_params(
10541054
sample_indices: Optional[base_params.sample_indices],
10551055
) -> Tuple[List[str], Optional[List[int]]]:
10561056
# Normalise sample sets.
1057-
sample_sets = self._prep_sample_sets_param(sample_sets=sample_sets)
1058-
sample_query = self._prep_sample_query_param(sample_query=sample_query)
1057+
prepared_sample_sets = self._prep_sample_sets_param(sample_sets=sample_sets)
1058+
prepared_sample_query = self._prep_sample_query_param(sample_query=sample_query)
10591059

1060-
if sample_query is not None:
1060+
# Delete original parameters to prevent accidental use.
1061+
del sample_sets
1062+
del sample_query
1063+
1064+
if prepared_sample_query is not None:
10611065
# Resolve query to a list of integers for more cache hits - we
10621066
# do this because there are different ways to write the same pandas
10631067
# query, and so it's better to evaluate the query and use a list of
10641068
# integer indices instead.
1065-
df_samples = self.sample_metadata(sample_sets=sample_sets)
1069+
df_samples = self.sample_metadata(sample_sets=prepared_sample_sets)
10661070
sample_query_options = sample_query_options or {}
1067-
loc_samples = df_samples.eval(sample_query, **sample_query_options).values
1071+
loc_samples = df_samples.eval(
1072+
prepared_sample_query, **sample_query_options
1073+
).values
10681074
sample_indices = np.nonzero(loc_samples)[0].tolist()
10691075

1070-
return sample_sets, sample_indices
1076+
return prepared_sample_sets, sample_indices
10711077

10721078
def _results_cache_add_analysis_params(self, params: dict):
10731079
super()._results_cache_add_analysis_params(params)

0 commit comments

Comments
 (0)