Skip to content

Commit 170ecf3

Browse files
Merge branch 'master' into fix/issue-1303-cloud-storage-retry-backoff
2 parents 960466b + 9c1bb56 commit 170ecf3

File tree

17 files changed

+606
-49
lines changed

17 files changed

+606
-49
lines changed

malariagen_data/anoph/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from numpydoc_decorator import doc # type: ignore
2929
from tqdm.auto import tqdm as tqdm_auto # type: ignore
3030
from tqdm.dask import TqdmCallback # type: ignore
31+
32+
from .safe_query import validate_query
3133
from yaspin import yaspin # type: ignore
3234
import xarray as xr
3335

@@ -980,10 +982,9 @@ def _filter_sample_dataset(
980982

981983
# Determine which samples match the sample query.
982984
if sample_query != "":
983-
# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
984-
loc_samples = df_samples.eval(
985-
sample_query, **sample_query_options, engine="python"
986-
)
985+
# Validate the query to prevent arbitrary code execution (GH-1292).
986+
validate_query(sample_query)
987+
loc_samples = df_samples.eval(sample_query, **sample_query_options)
987988
else:
988989
loc_samples = pd.Series(True, index=df_samples.index)
989990

malariagen_data/anoph/cnv_frq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_build_cohorts_from_sample_grouping,
1616
_add_frequency_ci,
1717
)
18+
from .safe_query import validate_query
1819
from ..util import (
1920
_check_types,
2021
_pandas_apply,
@@ -671,6 +672,7 @@ def _gene_cnv_frequencies_advanced(
671672

672673
debug("apply variant query")
673674
if variant_query is not None:
675+
validate_query(variant_query)
674676
loc_variants = df_variants.eval(variant_query).values
675677
# Convert boolean mask to integer indices for NumPy 2.x compatibility
676678
variant_indices = np.where(loc_variants)[0]

malariagen_data/anoph/frq_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ def _build_cohorts_from_sample_grouping(
147147
period_str = df_cohorts["period"].astype(str)
148148
df_cohorts["label"] = area_str + "_" + taxon_clean + "_" + period_str
149149

150-
# Apply minimum cohort size.
151-
df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True)
150+
# Apply minimum cohort size using safe boolean indexing.
151+
df_cohorts = df_cohorts.loc[df_cohorts["size"] >= min_cohort_size].reset_index(
152+
drop=True
153+
)
152154

153155
# Early check for no cohorts.
154156
if len(df_cohorts) == 0:

malariagen_data/anoph/genome_features.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def _genome_features_for_contig(self, *, contig: str, attributes: Tuple[str, ...
117117
)
118118
df = self._genome_features(attributes=attributes)
119119

120-
# Apply contig query.
121-
df = df.query(f"contig == '{contig}'")
120+
# Apply contig filter using safe boolean indexing.
121+
df = df.loc[df["contig"] == contig]
122122
return df
123123

124124
def _prep_gff_attributes(
@@ -162,9 +162,9 @@ def genome_features(
162162
contig=r.contig, attributes=attributes_normed
163163
)
164164
if r.end is not None:
165-
df_part = df_part.query(f"start <= {r.end}")
165+
df_part = df_part.loc[df_part["start"] <= r.end]
166166
if r.start is not None:
167-
df_part = df_part.query(f"end >= {r.start}")
167+
df_part = df_part.loc[df_part["end"] >= r.start]
168168
parts.append(df_part)
169169
df = pd.concat(parts, axis=0)
170170
return df.sort_values(["contig", "start"]).reset_index(drop=True).copy()
@@ -192,8 +192,8 @@ def genome_feature_children(
192192
df_gf["Parent"] = df_gf["Parent"].str.split(",")
193193
df_gf = df_gf.explode(column="Parent", ignore_index=True)
194194

195-
# Query to find children of the requested parent.
196-
df_children = df_gf.query(f"Parent == '{parent}'")
195+
# Filter to find children of the requested parent using safe indexing.
196+
df_children = df_gf.loc[df_gf["Parent"] == parent]
197197

198198
return df_children.copy()
199199

@@ -670,7 +670,9 @@ def plot_genes(
670670
def _plot_genes_setup_data(self, *, region):
671671
attributes = [a for a in self._gff_default_attributes if a != "Parent"]
672672
df_genome_features = self.genome_features(region=region, attributes=attributes)
673-
data = df_genome_features.query(f"type == '{self._gff_gene_type}'").copy()
673+
data = df_genome_features.loc[
674+
df_genome_features["type"] == self._gff_gene_type
675+
].copy()
674676
tooltips = [(a.capitalize(), f"@{a}") for a in attributes]
675677
tooltips += [("Location", "@contig:@start{,}-@end{,}")]
676678
return data, tooltips

malariagen_data/anoph/hap_data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import zarr # type: ignore
77
from numpydoc_decorator import doc # type: ignore
88

9+
from .safe_query import validate_query
10+
911
from ..util import (
1012
DIM_ALLELE,
1113
DIM_PLOIDY,
@@ -418,7 +420,8 @@ def haplotypes(
418420
df_samples.set_index("sample_id").loc[phased_samples].reset_index()
419421
)
420422

421-
# Apply the query.
423+
# Validate the query to prevent arbitrary code execution (GH-1292).
424+
validate_query(sample_query_prepped)
422425
sample_query_options = sample_query_options or {}
423426
loc_samples = df_samples_phased.eval(
424427
sample_query_prepped, **sample_query_options

malariagen_data/anoph/hapclust.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ..util import CacheMiss, _check_types, _pdist_abs_hamming, _pandas_apply
1010
from ..plotly_dendrogram import _plot_dendrogram, concat_clustering_subplots
11+
from .safe_query import validate_query
1112
from . import (
1213
base_params,
1314
plotly_params,
@@ -623,6 +624,7 @@ def transcript_haplotypes(
623624
"""
624625

625626
# Get SNP genotype allele counts for the transcript, applying snp_query
627+
validate_query(snp_query)
626628
df_eff = (
627629
self.snp_effects(
628630
transcript=transcript,

malariagen_data/anoph/karyotype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def load_inversion_tags(self, inversion: inversion_param) -> pd.DataFrame:
6262
else:
6363
with importlib.resources.path(resources, self._inversion_tag_path) as path:
6464
df_tag_snps = pd.read_csv(path, sep=",")
65-
return df_tag_snps.query(f"inversion == '{inversion}'").reset_index()
65+
return df_tag_snps.loc[df_tag_snps["inversion"] == inversion].reset_index()
6666

6767
@_check_types
6868
@doc(

malariagen_data/anoph/pca.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def __init__(
4242
The following additional parameters were also added in version 8.0.0:
4343
`site_class`, `cohort_size`, `min_cohort_size`, `max_cohort_size`,
4444
`random_seed`.
45-
4645
""",
4746
parameters=dict(
4847
imputation_method="""
@@ -69,6 +68,10 @@ def pca(
6968
sample_query: Optional[base_params.sample_query] = None,
7069
sample_query_options: Optional[base_params.sample_query_options] = None,
7170
sample_indices: Optional[base_params.sample_indices] = None,
71+
cohorts: Optional[base_params.cohorts] = None,
72+
cohort_size: Optional[base_params.cohort_size] = None,
73+
min_cohort_size: Optional[base_params.min_cohort_size] = None,
74+
max_cohort_size: Optional[base_params.max_cohort_size] = None,
7275
site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
7376
site_class: Optional[base_params.site_class] = None,
7477
min_minor_ac: Optional[
@@ -78,9 +81,6 @@ def pca(
7881
base_params.max_missing_an
7982
] = pca_params.max_missing_an_default,
8083
imputation_method: pca_params.imputation_method = pca_params.imputation_method_default,
81-
cohort_size: Optional[base_params.cohort_size] = None,
82-
min_cohort_size: Optional[base_params.min_cohort_size] = None,
83-
max_cohort_size: Optional[base_params.max_cohort_size] = None,
8484
exclude_samples: Optional[base_params.samples] = None,
8585
fit_exclude_samples: Optional[base_params.samples] = None,
8686
random_seed: base_params.random_seed = 42,
@@ -98,8 +98,44 @@ def pca(
9898

9999
## Normalize params for consistent hash value.
100100

101-
# Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
102-
# So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
101+
# Handle cohort downsampling.
102+
if cohorts is not None:
103+
if max_cohort_size is None:
104+
raise ValueError(
105+
"`max_cohort_size` is required when `cohorts` is provided."
106+
)
107+
if sample_indices is not None:
108+
raise ValueError(
109+
"Cannot use `sample_indices` with `cohorts` and `max_cohort_size`."
110+
)
111+
if cohort_size is not None or min_cohort_size is not None:
112+
raise ValueError(
113+
"Cannot use `cohort_size` or `min_cohort_size` with `cohorts`."
114+
)
115+
df_samples = self.sample_metadata(
116+
sample_sets=sample_sets,
117+
sample_query=sample_query,
118+
sample_query_options=sample_query_options,
119+
)
120+
# N.B., we are going to overwrite the sample_indices parameter here.
121+
groups = df_samples.groupby(cohorts, sort=False)
122+
ix = []
123+
for _, group in groups:
124+
if len(group) > max_cohort_size:
125+
ix.extend(
126+
group.sample(
127+
n=max_cohort_size, random_state=random_seed, replace=False
128+
).index
129+
)
130+
else:
131+
ix.extend(group.index)
132+
sample_indices = ix
133+
# From this point onwards, the sample_query is no longer needed, because
134+
# the sample selection is defined by the sample_indices.
135+
sample_query = None
136+
sample_query_options = None
137+
138+
# Normalize params for consistent hash value.
103139
(
104140
prepared_sample_sets,
105141
prepared_sample_indices,
@@ -132,6 +168,7 @@ def pca(
132168
max_missing_an=max_missing_an,
133169
imputation_method=imputation_method,
134170
n_components=n_components,
171+
cohorts=cohorts,
135172
cohort_size=cohort_size,
136173
min_cohort_size=min_cohort_size,
137174
max_cohort_size=max_cohort_size,
@@ -149,10 +186,10 @@ def pca(
149186
self.results_cache_set(name=name, params=params, results=results)
150187

151188
# Unpack results.
152-
coords = results["coords"]
153-
evr = results["evr"]
154-
samples = results["samples"]
155-
loc_keep_fit = results["loc_keep_fit"]
189+
coords = np.array(results["coords"])
190+
evr = np.array(results["evr"])
191+
samples = np.array(results["samples"])
192+
loc_keep_fit = np.array(results["loc_keep_fit"])
156193

157194
# Create a new DataFrame containing the PCA coords data.
158195
df_pca = pd.DataFrame(coords, index=samples)
@@ -205,6 +242,7 @@ def _pca(
205242
random_seed,
206243
chunks,
207244
inline_array,
245+
**kwargs,
208246
):
209247
# Load diplotypes.
210248
ds_diplotypes = self.biallelic_diplotypes(

0 commit comments

Comments (0)