Skip to content

Commit 6c4e74f

Browse files
committed
WIP: handle sample_indices when surveillance_use_only
1 parent 9b7d6cc commit 6c4e74f

2 files changed

Lines changed: 133 additions & 44 deletions

File tree

malariagen_data/anoph/sample_metadata.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,36 @@ def _prep_sample_selection_cache_params(
10601060
sample_query_options: Optional[base_params.sample_query_options],
10611061
sample_indices: Optional[base_params.sample_indices],
10621062
) -> Tuple[List[str], Optional[List[int]]]:
1063+
# Check that either sample_query xor sample_indices are provided.
1064+
base_params.validate_sample_selection_params(
1065+
sample_query=sample_query, sample_indices=sample_indices
1066+
)
1067+
1068+
# Resolve query to a list of integers for more cache hits - we
1069+
# do this because there are different ways to write the same pandas
1070+
# query, and so it's better to evaluate the query and use a list of
1071+
# integer indices instead.
1072+
1073+
# Scenario 1: No `sample_query` nor `sample_indices` were given,
1074+
# and there is no internal `sample_query`,
1075+
# so no `sample_indices` will be returned.
1076+
1077+
# Scenario 2: No `sample_query` nor `sample_indices` were given,
1078+
# but there is an internal `sample_query`,
1079+
# which will be converted into `sample_indices` and returned.
1080+
1081+
# Scenario 3: Only `sample_query` has been provided,
1082+
# which will be converted into `sample_indices` and returned.
1083+
# This will be handled the same as Scenario 2.
1084+
1085+
# Scenario 4: Only `sample_indices` has been provided,
1086+
# and there is no internal `sample_query`,
1087+
# simply return `sample_indices`.
1088+
1089+
# Scenario 5: Only `sample_indices` has been provided,
1090+
# but there is also an internal `sample_query`, still return `sample_indices`,
1091+
# which ought to already align with `sample_metadata`.
1092+
10631093
# Normalise sample sets.
10641094
prepared_sample_sets = self._prep_sample_sets_param(sample_sets=sample_sets)
10651095
prepared_sample_query = self._prep_sample_query_param(sample_query=sample_query)
@@ -1068,20 +1098,46 @@ def _prep_sample_selection_cache_params(
10681098
del sample_sets
10691099
del sample_query
10701100

1071-
if prepared_sample_query is not None:
1072-
# Resolve query to a list of integers for more cache hits - we
1073-
# do this because there are different ways to write the same pandas
1074-
# query, and so it's better to evaluate the query and use a list of
1075-
# integer indices instead.
1101+
# Start with assuming there are no sample indices.
1102+
# This can be returned if there is no `prepared_sample_query` nor `sample_indices`.
1103+
prepared_sample_indices = None
1104+
1105+
# If there is a `prepared_sample_query` but no `sample_indices`...
1106+
if prepared_sample_query is not None and sample_indices is None:
1107+
# Get the unfiltered sample metadata for the given sample sets.
1108+
# Note: we don't want to pass the `sample_query` to `sample_metadata` here
1109+
# because we want to get the sample indices that represent the `sample_query`.
10761110
df_samples = self.sample_metadata(sample_sets=prepared_sample_sets)
1111+
1112+
# Default the sample_query_options to an empty dict.
10771113
sample_query_options = sample_query_options or {}
1114+
10781115
# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
1116+
# Get the Pandas Series as a NumPy array of Boolean values.
1117+
# Note: if `prepared_sample_query` is an internal query, this will select all samples,
1118+
# since `sample_metadata` should have already applied the internal query.
10791119
loc_samples = df_samples.eval(
10801120
prepared_sample_query, **sample_query_options, engine="python"
10811121
).values
1082-
sample_indices = np.nonzero(loc_samples)[0].tolist()
10831122

1084-
return prepared_sample_sets, sample_indices
1123+
# Convert the sample indices to a list.
1124+
# Get the indices of the True values in the Boolean array and convert it to a list of integers.
1125+
prepared_sample_indices = np.nonzero(loc_samples)[0].tolist()
1126+
1127+
# If there is a `prepared_sample_query` and a `sample_indices`...
1128+
elif prepared_sample_query is not None and sample_indices is not None:
1129+
# Given that we don't allow both `sample_query` and `sample_indices` params in this function,
1130+
# we can deduce that the `prepared_sample_query` has resulted from an internal query.
1131+
# Given that `sample_indices` should be aligned with the results of `sample_metadata`,
1132+
# which should already apply the internal query, simply return the given `sample_indices`.
1133+
1134+
prepared_sample_indices = sample_indices
1135+
1136+
# If there is no `prepared_sample_query` but there is a `sample_indices`...
1137+
elif prepared_sample_query is None and sample_indices is not None:
1138+
prepared_sample_indices = sample_indices
1139+
1140+
return prepared_sample_sets, prepared_sample_indices
10851141

10861142
def _results_cache_add_analysis_params(self, params: dict):
10871143
super()._results_cache_add_analysis_params(params)

malariagen_data/anoph/snp_data.py

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ def _locate_site_class(
753753
try:
754754
loc_ann = self._cache_locate_site_class[cache_key]
755755

756-
except KeyError:
756+
except KeyError as exc:
757757
# Access site annotations data.
758758
ds_ann = self._site_annotations_raw(
759759
contig=region.contig,
@@ -877,7 +877,7 @@ def _locate_site_class(
877877
) | ((seq_cls == SEQ_CLS_DOWNSTREAM) & (seq_relpos_start > 10_000))
878878

879879
else:
880-
raise NotImplementedError(site_class)
880+
raise NotImplementedError(site_class) from exc
881881

882882
# N.B., site annotations data are provided for every position in the genome. We need to
883883
# therefore subset to SNP positions.
@@ -1007,33 +1007,44 @@ def snp_calls(
10071007
)
10081008

10091009
# Normalise parameters.
1010-
prepared_regions: Tuple[Region, ...] = tuple(parse_multi_region(self, region))
1011-
prepared_sample_sets: Tuple[str, ...] = tuple(
1012-
self._prep_sample_sets_param(sample_sets=sample_sets)
1013-
)
1014-
1015-
sample_query_prepped = self._prep_sample_query_param(sample_query=sample_query)
1016-
1017-
if sample_indices is not None:
1018-
prepared_sample_indices: Optional[Tuple[int, ...]] = tuple(sample_indices)
1019-
else:
1020-
prepared_sample_indices = sample_indices
1021-
1010+
prepared_regions = parse_multi_region(self, region)
10221011
prepared_site_mask = self._prep_optional_site_mask_param(site_mask=site_mask)
10231012

1013+
# Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
1014+
# So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
1015+
(
1016+
prepared_sample_sets,
1017+
prepared_sample_indices,
1018+
) = self._prep_sample_selection_cache_params(
1019+
sample_sets=sample_sets,
1020+
sample_query=sample_query,
1021+
sample_query_options=sample_query_options,
1022+
sample_indices=sample_indices,
1023+
)
1024+
10241025
# Delete original parameters to prevent accidental use.
10251026
del sample_sets
10261027
del sample_query
10271028
del sample_indices
10281029
del region
10291030
del site_mask
10301031

1032+
# Convert lists to tuples to avoid CacheMiss "TypeError: unhashable type: 'list'".
1033+
prepared_regions_tuple: Tuple[Region, ...] = tuple(prepared_regions)
1034+
prepared_sample_sets_tuple: Optional[Tuple[str, ...]] = (
1035+
tuple(prepared_sample_sets) if prepared_sample_sets is not None else None
1036+
)
1037+
prepared_sample_indices_tuple: Optional[Tuple[int, ...]] = (
1038+
tuple(prepared_sample_indices)
1039+
if prepared_sample_indices is not None
1040+
else None
1041+
)
1042+
1043+
# Note: `_snp_calls` should only take `sample_indices`, not `sample_query`, to facilitate caching.
10311044
return self._snp_calls(
1032-
regions=prepared_regions,
1033-
sample_sets=prepared_sample_sets,
1034-
sample_query=sample_query_prepped,
1035-
sample_query_options=sample_query_options,
1036-
sample_indices=prepared_sample_indices,
1045+
regions=prepared_regions_tuple,
1046+
sample_sets=prepared_sample_sets_tuple,
1047+
sample_indices=prepared_sample_indices_tuple,
10371048
site_mask=prepared_site_mask,
10381049
site_class=site_class,
10391050
cohort_size=cohort_size,
@@ -1127,8 +1138,6 @@ def _snp_calls(
11271138
*,
11281139
regions: Tuple[Region, ...],
11291140
sample_sets,
1130-
sample_query,
1131-
sample_query_options,
11321141
sample_indices,
11331142
site_mask,
11341143
site_class,
@@ -1139,10 +1148,15 @@ def _snp_calls(
11391148
inline_array,
11401149
chunks,
11411150
):
1142-
# Note: sample_sets and sample_query should be "prepared" before being passed to this private function.
1151+
## Get SNP calls and concatenate multiple sample sets and/or regions.
1152+
1153+
# Note: sample_sets should be "prepared" before being passed to this private function.
1154+
1155+
# Note: `_snp_calls` should only take `sample_indices`, not `sample_query`.
1156+
# Use `_prep_sample_selection_cache_params` to convert `sample_query` to `sample_indices`.
1157+
1158+
# Note: we don't cache different sample_indices subsets, which are selected below.
11431159

1144-
# Get SNP calls and concatenate multiple sample sets and/or regions.
1145-
# Note: we don't cache different sample_query or sample_indices subsets.
11461160
ds = self._cached_snp_calls(
11471161
regions=regions,
11481162
sample_sets=sample_sets,
@@ -1153,22 +1167,41 @@ def _snp_calls(
11531167
)
11541168

11551169
# Handle sample selection.
1156-
if sample_query is not None:
1170+
if sample_indices is not None:
1171+
# Note: `sample_indices` could be any tuple of integers, while the `ds` DataSet will contain data for all samples in the `sample_sets`.
1172+
# In other words, the internal `sample_query` is not being applied to `ds`.
1173+
# We need to get the filtered set of samples from `sample_metadata` and then select samples based on that set.
1174+
11571175
# Get the relevant sample metadata.
1158-
df_samples = self.sample_metadata(sample_sets=sample_sets)
1176+
relevant_samples_df = self.sample_metadata(sample_sets=sample_sets)
11591177

1160-
# If there are no sample query options, then default to an empty dict.
1161-
sample_query_options = sample_query_options or {}
1178+
# We need to select only the samples that are identified by the `sample_indices` tuple relative to the results of `sample_metadata`.
1179+
# However, the `ds` DataSet contains data for all samples in the `sample_sets`, regardless of any internal `sample_query`.
11621180

1163-
ds = self._filter_sample_dataset(
1164-
ds=ds,
1165-
df_samples=df_samples,
1166-
sample_query=sample_query,
1167-
sample_query_options=sample_query_options,
1168-
)
1181+
# Get the samples identified via `sample_indices`.
1182+
# Note: this might raise `IndexingError` if the user provides bad indices, e.g. "positional indexers are out-of-bounds".
1183+
# Note: `sample_indices` needs to be a list rather than tuple for `iloc`, otherwise `IndexingError`, e.g. "Too many indexers".
1184+
sample_indices_as_list = list(sample_indices)
1185+
selected_samples_df = relevant_samples_df.iloc[sample_indices_as_list]
11691186

1170-
elif sample_indices is not None:
1171-
ds = ds.isel(samples=list(sample_indices))
1187+
# Get the selected sample ids from the sample metadata DataFrame.
1188+
relevant_sample_ids = selected_samples_df["sample_id"].values
1189+
1190+
# Get all the sample ids from the unfiltered Dataset.
1191+
ds_sample_ids = ds.coords["sample_id"].values
1192+
1193+
# Get the indices of samples in the Dataset that match the relevant sample ids.
1194+
# Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
1195+
relevant_sample_indices = np.where(
1196+
np.isin(ds_sample_ids, relevant_sample_ids)
1197+
)[0]
1198+
1199+
# Preserve the behaviour of raising a `ValueError` instead of empty results.
1200+
if relevant_sample_indices.size == 0:
1201+
raise ValueError("No relevant samples found.")
1202+
1203+
# Select only the relevant samples from the Dataset.
1204+
ds = ds.isel(samples=relevant_sample_indices)
11721205

11731206
# Handle cohort size, overrides min and max.
11741207
if cohort_size is not None:
@@ -1939,7 +1972,7 @@ def _biallelic_diplotypes(
19391972
inline_array,
19401973
chunks,
19411974
):
1942-
# Note: this uses sample_indices and should not expect a sample_query.
1975+
# Note: this function uses sample_indices and should not expect a sample_query.
19431976

19441977
# Access biallelic SNPs.
19451978
ds = self.biallelic_snp_calls(

0 commit comments

Comments
 (0)