@@ -753,7 +753,7 @@ def _locate_site_class(
753753 try :
754754 loc_ann = self ._cache_locate_site_class [cache_key ]
755755
756- except KeyError :
756+ except KeyError as exc :
757757 # Access site annotations data.
758758 ds_ann = self ._site_annotations_raw (
759759 contig = region .contig ,
@@ -877,7 +877,7 @@ def _locate_site_class(
877877 ) | ((seq_cls == SEQ_CLS_DOWNSTREAM ) & (seq_relpos_start > 10_000 ))
878878
879879 else :
880- raise NotImplementedError (site_class )
880+ raise NotImplementedError (site_class ) from exc
881881
882882 # N.B., site annotations data are provided for every position in the genome. We need to
883883 # therefore subset to SNP positions.
@@ -1007,33 +1007,44 @@ def snp_calls(
10071007 )
10081008
10091009 # Normalise parameters.
1010- prepared_regions : Tuple [Region , ...] = tuple (parse_multi_region (self , region ))
1011- prepared_sample_sets : Tuple [str , ...] = tuple (
1012- self ._prep_sample_sets_param (sample_sets = sample_sets )
1013- )
1014-
1015- sample_query_prepped = self ._prep_sample_query_param (sample_query = sample_query )
1016-
1017- if sample_indices is not None :
1018- prepared_sample_indices : Optional [Tuple [int , ...]] = tuple (sample_indices )
1019- else :
1020- prepared_sample_indices = sample_indices
1021-
1010+ prepared_regions = parse_multi_region (self , region )
10221011 prepared_site_mask = self ._prep_optional_site_mask_param (site_mask = site_mask )
10231012
1013+ # Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
1014+ # So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
1015+ (
1016+ prepared_sample_sets ,
1017+ prepared_sample_indices ,
1018+ ) = self ._prep_sample_selection_cache_params (
1019+ sample_sets = sample_sets ,
1020+ sample_query = sample_query ,
1021+ sample_query_options = sample_query_options ,
1022+ sample_indices = sample_indices ,
1023+ )
1024+
10241025 # Delete original parameters to prevent accidental use.
10251026 del sample_sets
10261027 del sample_query
10271028 del sample_indices
10281029 del region
10291030 del site_mask
10301031
1032+ # Convert lists to tuples to avoid CacheMiss "TypeError: unhashable type: 'list'".
1033+ prepared_regions_tuple : Tuple [Region , ...] = tuple (prepared_regions )
1034+ prepared_sample_sets_tuple : Optional [Tuple [str , ...]] = (
1035+ tuple (prepared_sample_sets ) if prepared_sample_sets is not None else None
1036+ )
1037+ prepared_sample_indices_tuple : Optional [Tuple [int , ...]] = (
1038+ tuple (prepared_sample_indices )
1039+ if prepared_sample_indices is not None
1040+ else None
1041+ )
1042+
1043+ # Note: `_snp_calls` should only take `sample_indices`, not `sample_query`, to facilitate caching.
10311044 return self ._snp_calls (
1032- regions = prepared_regions ,
1033- sample_sets = prepared_sample_sets ,
1034- sample_query = sample_query_prepped ,
1035- sample_query_options = sample_query_options ,
1036- sample_indices = prepared_sample_indices ,
1045+ regions = prepared_regions_tuple ,
1046+ sample_sets = prepared_sample_sets_tuple ,
1047+ sample_indices = prepared_sample_indices_tuple ,
10371048 site_mask = prepared_site_mask ,
10381049 site_class = site_class ,
10391050 cohort_size = cohort_size ,
@@ -1127,8 +1138,6 @@ def _snp_calls(
11271138 * ,
11281139 regions : Tuple [Region , ...],
11291140 sample_sets ,
1130- sample_query ,
1131- sample_query_options ,
11321141 sample_indices ,
11331142 site_mask ,
11341143 site_class ,
@@ -1139,10 +1148,15 @@ def _snp_calls(
11391148 inline_array ,
11401149 chunks ,
11411150 ):
1142- # Note: sample_sets and sample_query should be "prepared" before being passed to this private function.
1151+ # Get SNP calls and concatenate multiple sample sets and/or regions.
1152+
1153+ # Note: sample_sets should be "prepared" before being passed to this private function.
1154+
1155+ # Note: `_snp_calls` should only take `sample_indices`, not `sample_query`.
1156+ # Use `_prep_sample_selection_cache_params` to convert `sample_query` to `sample_indices`.
1157+
1158+ # Note: we don't cache different sample_indices subsets, which are selected below.
11431159
1144- # Get SNP calls and concatenate multiple sample sets and/or regions.
1145- # Note: we don't cache different sample_query or sample_indices subsets.
11461160 ds = self ._cached_snp_calls (
11471161 regions = regions ,
11481162 sample_sets = sample_sets ,
@@ -1153,22 +1167,41 @@ def _snp_calls(
11531167 )
11541168
11551169 # Handle sample selection.
1156- if sample_query is not None :
1170+ if sample_indices is not None :
1171+ # Note: `sample_indices` could be any tuple of integers, while the `ds` DataSet will contain data for all samples in the `sample_sets`.
1172+ # In other words, the internal `sample_query` is not being applied to `ds`.
1173+ # We need to get the filtered set of samples from `sample_metadata` and then select samples based on that set.
1174+
11571175 # Get the relevant sample metadata.
1158- df_samples = self .sample_metadata (sample_sets = sample_sets )
1176+ relevant_samples_df = self .sample_metadata (sample_sets = sample_sets )
11591177
1160- # If there are no sample query options, then default to an empty dict .
1161- sample_query_options = sample_query_options or {}
1178+ # We need to select only the samples at the positions given by the `sample_indices` tuple within the results of `sample_metadata`.
1179+ # However, the `ds` DataSet contains data for all samples in the `sample_sets`, regardless of any internal `sample_query`.
11621180
1163- ds = self ._filter_sample_dataset (
1164- ds = ds ,
1165- df_samples = df_samples ,
1166- sample_query = sample_query ,
1167- sample_query_options = sample_query_options ,
1168- )
1181+ # Get the samples identified via `sample_indices`.
1182+ # Note: this might raise `IndexingError` if the user provides bad indices, e.g. "positional indexers are out-of-bounds".
1183+ # Note: `sample_indices` needs to be a list rather than a tuple for `iloc`, because `iloc` treats a tuple as one indexer per axis and raises `IndexingError`, e.g. "Too many indexers".
1184+ sample_indices_as_list = list (sample_indices )
1185+ selected_samples_df = relevant_samples_df .iloc [sample_indices_as_list ]
11691186
1170- elif sample_indices is not None :
1171- ds = ds .isel (samples = list (sample_indices ))
1187+ # Get the selected sample ids from the sample metadata DataFrame.
1188+ relevant_sample_ids = selected_samples_df ["sample_id" ].values
1189+
1190+ # Get all the sample ids from the unfiltered Dataset.
1191+ ds_sample_ids = ds .coords ["sample_id" ].values
1192+
1193+ # Get the indices of samples in the Dataset that match the relevant sample ids.
1194+ # Note: we use `[0]` to get the first element of the tuple returned by `np.where`.
1195+ relevant_sample_indices = np .where (
1196+ np .isin (ds_sample_ids , relevant_sample_ids )
1197+ )[0 ]
1198+
1199+ # Preserve the behaviour of raising a `ValueError` instead of returning empty results.
1200+ if relevant_sample_indices .size == 0 :
1201+ raise ValueError ("No relevant samples found." )
1202+
1203+ # Select only the relevant samples from the Dataset.
1204+ ds = ds .isel (samples = relevant_sample_indices )
11721205
11731206 # Handle cohort size, overrides min and max.
11741207 if cohort_size is not None :
@@ -1939,7 +1972,7 @@ def _biallelic_diplotypes(
19391972 inline_array ,
19401973 chunks ,
19411974 ):
1942- # Note: this uses sample_indices and should not expect a sample_query.
1975+ # Note: this function uses sample_indices and should not expect a sample_query.
19431976
19441977 # Access biallelic SNPs.
19451978 ds = self .biallelic_snp_calls (
0 commit comments