Skip to content

Commit e6ef1cf

Browse files
fix: sanitize user input passed to DataFrame.eval()/query() to prevent code injection (GH-1292)
- Add safe_query.py module with AST-based validation that restricts query expressions to comparisons, boolean logic, constants, and column names - Validate all sample_query, variant_query, and snp_query strings before passing them to pandas eval()/query() - Remove engine="python" from all DataFrame.eval()/query() calls to prevent arbitrary Python execution - Replace f-string interpolation in DataFrame.query() calls with safe boolean indexing (genome_features.py, util.py, frq_base.py, karyotype.py) - Add comprehensive test suite (43 tests) covering safe expressions, malicious payloads, and edge cases Closes #1292
1 parent 7f9f74e commit e6ef1cf

13 files changed

Lines changed: 381 additions & 23 deletions

File tree

malariagen_data/anoph/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from numpydoc_decorator import doc # type: ignore
2929
from tqdm.auto import tqdm as tqdm_auto # type: ignore
3030
from tqdm.dask import TqdmCallback # type: ignore
31+
32+
from .safe_query import validate_query
3133
from yaspin import yaspin # type: ignore
3234
import xarray as xr
3335

@@ -980,10 +982,9 @@ def _filter_sample_dataset(
980982

981983
# Determine which samples match the sample query.
982984
if sample_query != "":
983-
# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
984-
loc_samples = df_samples.eval(
985-
sample_query, **sample_query_options, engine="python"
986-
)
985+
# Validate the query to prevent arbitrary code execution (GH-1292).
986+
validate_query(sample_query)
987+
loc_samples = df_samples.eval(sample_query, **sample_query_options)
987988
else:
988989
loc_samples = pd.Series(True, index=df_samples.index)
989990

malariagen_data/anoph/cnv_frq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_build_cohorts_from_sample_grouping,
1616
_add_frequency_ci,
1717
)
18+
from .safe_query import validate_query
1819
from ..util import (
1920
_check_types,
2021
_pandas_apply,
@@ -671,6 +672,7 @@ def _gene_cnv_frequencies_advanced(
671672

672673
debug("apply variant query")
673674
if variant_query is not None:
675+
validate_query(variant_query)
674676
loc_variants = df_variants.eval(variant_query).values
675677
# Convert boolean mask to integer indices for NumPy 2.x compatibility
676678
variant_indices = np.where(loc_variants)[0]

malariagen_data/anoph/frq_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ def _build_cohorts_from_sample_grouping(
147147
period_str = df_cohorts["period"].astype(str)
148148
df_cohorts["label"] = area_str + "_" + taxon_clean + "_" + period_str
149149

150-
# Apply minimum cohort size.
151-
df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True)
150+
# Apply minimum cohort size using safe boolean indexing.
151+
df_cohorts = df_cohorts.loc[df_cohorts["size"] >= min_cohort_size].reset_index(
152+
drop=True
153+
)
152154

153155
# Early check for no cohorts.
154156
if len(df_cohorts) == 0:

malariagen_data/anoph/genome_features.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def _genome_features_for_contig(self, *, contig: str, attributes: Tuple[str, ...
117117
)
118118
df = self._genome_features(attributes=attributes)
119119

120-
# Apply contig query.
121-
df = df.query(f"contig == '{contig}'")
120+
# Apply contig filter using safe boolean indexing.
121+
df = df.loc[df["contig"] == contig]
122122
return df
123123

124124
def _prep_gff_attributes(
@@ -162,9 +162,9 @@ def genome_features(
162162
contig=r.contig, attributes=attributes_normed
163163
)
164164
if r.end is not None:
165-
df_part = df_part.query(f"start <= {r.end}")
165+
df_part = df_part.loc[df_part["start"] <= r.end]
166166
if r.start is not None:
167-
df_part = df_part.query(f"end >= {r.start}")
167+
df_part = df_part.loc[df_part["end"] >= r.start]
168168
parts.append(df_part)
169169
df = pd.concat(parts, axis=0)
170170
return df.sort_values(["contig", "start"]).reset_index(drop=True).copy()
@@ -192,8 +192,8 @@ def genome_feature_children(
192192
df_gf["Parent"] = df_gf["Parent"].str.split(",")
193193
df_gf = df_gf.explode(column="Parent", ignore_index=True)
194194

195-
# Query to find children of the requested parent.
196-
df_children = df_gf.query(f"Parent == '{parent}'")
195+
# Filter to find children of the requested parent using safe indexing.
196+
df_children = df_gf.loc[df_gf["Parent"] == parent]
197197

198198
return df_children.copy()
199199

@@ -670,7 +670,9 @@ def plot_genes(
670670
def _plot_genes_setup_data(self, *, region):
671671
attributes = [a for a in self._gff_default_attributes if a != "Parent"]
672672
df_genome_features = self.genome_features(region=region, attributes=attributes)
673-
data = df_genome_features.query(f"type == '{self._gff_gene_type}'").copy()
673+
data = df_genome_features.loc[
674+
df_genome_features["type"] == self._gff_gene_type
675+
].copy()
674676
tooltips = [(a.capitalize(), f"@{a}") for a in attributes]
675677
tooltips += [("Location", "@contig:@start{,}-@end{,}")]
676678
return data, tooltips

malariagen_data/anoph/hap_data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import zarr # type: ignore
77
from numpydoc_decorator import doc # type: ignore
88

9+
from .safe_query import validate_query
10+
911
from ..util import (
1012
DIM_ALLELE,
1113
DIM_PLOIDY,
@@ -418,7 +420,8 @@ def haplotypes(
418420
df_samples.set_index("sample_id").loc[phased_samples].reset_index()
419421
)
420422

421-
# Apply the query.
423+
# Validate the query to prevent arbitrary code execution (GH-1292).
424+
validate_query(sample_query_prepped)
422425
sample_query_options = sample_query_options or {}
423426
loc_samples = df_samples_phased.eval(
424427
sample_query_prepped, **sample_query_options

malariagen_data/anoph/hapclust.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ..util import CacheMiss, _check_types, _pdist_abs_hamming, _pandas_apply
1010
from ..plotly_dendrogram import _plot_dendrogram, concat_clustering_subplots
11+
from .safe_query import validate_query
1112
from . import (
1213
base_params,
1314
plotly_params,
@@ -623,6 +624,7 @@ def transcript_haplotypes(
623624
"""
624625

625626
# Get SNP genotype allele counts for the transcript, applying snp_query
627+
validate_query(snp_query)
626628
df_eff = (
627629
self.snp_effects(
628630
transcript=transcript,

malariagen_data/anoph/karyotype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def load_inversion_tags(self, inversion: inversion_param) -> pd.DataFrame:
6262
else:
6363
with importlib.resources.path(resources, self._inversion_tag_path) as path:
6464
df_tag_snps = pd.read_csv(path, sep=",")
65-
return df_tag_snps.query(f"inversion == '{inversion}'").reset_index()
65+
return df_tag_snps.loc[df_tag_snps["inversion"] == inversion].reset_index()
6666

6767
@_check_types
6868
@doc(
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Safe query validation for pandas eval/query expressions.
2+
3+
This module provides AST-based validation of query strings to prevent
4+
arbitrary code execution via pandas DataFrame.eval() and DataFrame.query().
5+
6+
Only a restricted subset of Python expressions is allowed:
7+
- Boolean operators: and, or, not
8+
- Comparison operators: ==, !=, <, <=, >, >=, in, not in, is
9+
- Arithmetic operators: +, -, *, /, //, %, **
10+
- Unary operators: +, -, ~, not
11+
- Constants: strings, numbers, booleans, None
12+
- Names: must match an allowlist of known column names (if provided)
13+
- Parenthesized expressions
14+
15+
Forbidden constructs include:
16+
- Function calls (e.g., __import__('os'))
17+
- Attribute access (e.g., os.system)
18+
- Subscript/indexing (e.g., x[0])
19+
- Comprehensions, lambdas, f-strings, starred expressions
20+
- Any identifier containing double underscores (__)
21+
"""
22+
23+
import ast
24+
from typing import Optional, Set
25+
26+
27+
# AST node types that are safe in query expressions.
28+
_SAFE_NODE_TYPES = (
29+
ast.Expression,
30+
ast.BoolOp,
31+
ast.BinOp,
32+
ast.UnaryOp,
33+
ast.Compare,
34+
ast.And,
35+
ast.Or,
36+
ast.Not,
37+
ast.Add,
38+
ast.Sub,
39+
ast.Mult,
40+
ast.Div,
41+
ast.FloorDiv,
42+
ast.Mod,
43+
ast.Pow,
44+
ast.USub,
45+
ast.UAdd,
46+
ast.Invert,
47+
ast.Eq,
48+
ast.NotEq,
49+
ast.Lt,
50+
ast.LtE,
51+
ast.Gt,
52+
ast.GtE,
53+
ast.In,
54+
ast.NotIn,
55+
ast.Is,
56+
ast.IsNot,
57+
ast.Constant,
58+
ast.Name,
59+
ast.Load,
60+
ast.Tuple,
61+
ast.List,
62+
)
63+
64+
65+
class UnsafeQueryError(ValueError):
66+
"""Raised when a query string contains unsafe constructs."""
67+
68+
pass
69+
70+
71+
def _validate_node(node: ast.AST, allowed_names: Optional[Set[str]] = None) -> None:
72+
"""Recursively validate that an AST node contains only safe constructs.
73+
74+
Parameters
75+
----------
76+
node : ast.AST
77+
The AST node to validate.
78+
allowed_names : set of str, optional
79+
If provided, restrict identifier names to this set.
80+
81+
Raises
82+
------
83+
UnsafeQueryError
84+
If the node or any of its children contain unsafe constructs.
85+
"""
86+
if not isinstance(node, _SAFE_NODE_TYPES):
87+
raise UnsafeQueryError(
88+
f"Unsafe expression: {type(node).__name__} nodes are not allowed "
89+
f"in query strings. Only comparisons, boolean logic, and constants "
90+
f"are permitted."
91+
)
92+
93+
if isinstance(node, ast.Name):
94+
name = node.id
95+
# Block dunder identifiers.
96+
if "__" in name:
97+
raise UnsafeQueryError(
98+
f"Unsafe expression: identifier '{name}' contains double "
99+
f"underscores and is not allowed in query strings."
100+
)
101+
# Check against allowlist if provided.
102+
if allowed_names is not None and name not in allowed_names:
103+
# Allow common boolean literals that pandas recognizes.
104+
if name not in {"True", "False", "None"}:
105+
raise UnsafeQueryError(
106+
f"Unknown column name '{name}' in query string. "
107+
f"Allowed column names: {sorted(allowed_names)}"
108+
)
109+
110+
# Recurse into child nodes.
111+
for child in ast.iter_child_nodes(node):
112+
_validate_node(child, allowed_names)
113+
114+
115+
def validate_query(query: str, allowed_names: Optional[Set[str]] = None) -> None:
116+
"""Validate that a query string is safe for use with pandas eval/query.
117+
118+
Parameters
119+
----------
120+
query : str
121+
The query string to validate.
122+
allowed_names : set of str, optional
123+
If provided, restrict identifier names to this set of known column
124+
names. If None, any identifier (except those containing ``__``) is
125+
allowed.
126+
127+
Raises
128+
------
129+
UnsafeQueryError
130+
If the query contains unsafe constructs such as function calls,
131+
attribute access, or dunder identifiers.
132+
"""
133+
if not isinstance(query, str):
134+
raise UnsafeQueryError(f"Query must be a string, got {type(query).__name__}.")
135+
136+
query = query.strip()
137+
if not query:
138+
raise UnsafeQueryError("Query string must not be empty.")
139+
140+
try:
141+
tree = ast.parse(query, mode="eval")
142+
except SyntaxError as e:
143+
raise UnsafeQueryError(f"Query string is not a valid expression: {e}") from e
144+
145+
_validate_node(tree, allowed_names)

malariagen_data/anoph/sample_metadata.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import plotly.express as px # type: ignore
2525
from numpydoc_decorator import doc # type: ignore
2626

27+
from .safe_query import validate_query
28+
2729
from ..util import _check_types
2830
from . import base_params, map_params, plotly_params
2931
from .base import AnophelesBase
@@ -808,10 +810,9 @@ def sample_metadata(
808810
# zero-result queries and provide a helpful warning.
809811
df_before_query = df_samples
810812

811-
# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
812-
df_samples = df_samples.query(
813-
prepared_sample_query, **sample_query_options, engine="python"
814-
)
813+
# Validate the query to prevent arbitrary code execution (GH-1292).
814+
validate_query(prepared_sample_query)
815+
df_samples = df_samples.query(prepared_sample_query, **sample_query_options)
815816
df_samples = df_samples.reset_index(drop=True)
816817

817818
# Warn if query returned zero results on a non-empty dataset.
@@ -1197,12 +1198,13 @@ def _prep_sample_selection_cache_params(
11971198
# Default the sample_query_options to an empty dict.
11981199
sample_query_options = sample_query_options or {}
11991200

1200-
# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
1201+
# Validate the query to prevent arbitrary code execution (GH-1292).
12011202
# Get the Pandas Series as a NumPy array of Boolean values.
12021203
# Note: if `prepared_sample_query` is an internal query, this will select all samples,
12031204
# since `sample_metadata` should have already applied the internal query.
1205+
validate_query(prepared_sample_query)
12041206
loc_samples = df_samples.eval(
1205-
prepared_sample_query, **sample_query_options, engine="python"
1207+
prepared_sample_query, **sample_query_options
12061208
).values
12071209

12081210
# Convert the sample indices to a list.
@@ -1368,6 +1370,7 @@ def _setup_sample_symbol(
13681370
)
13691371
data["symbol"] = ""
13701372
for key, value in symbol.items():
1373+
validate_query(value)
13711374
data.loc[data.query(value).index, "symbol"] = key
13721375
symbol_prepped = "symbol"
13731376

@@ -1421,6 +1424,7 @@ def _setup_sample_colors_plotly(
14211424
)
14221425
data["color"] = ""
14231426
for key, value in color.items():
1427+
validate_query(value)
14241428
data.loc[data.query(value).index, "color"] = key
14251429
color_prepped = "color"
14261430

@@ -1654,6 +1658,7 @@ def cohorts(
16541658
self._cache_cohorts[cache_key] = df_cohorts
16551659

16561660
if query is not None:
1661+
validate_query(query)
16571662
df_cohorts = df_cohorts.query(query)
16581663
df_cohorts = df_cohorts.reset_index(drop=True)
16591664

@@ -1872,6 +1877,7 @@ def _locate_cohorts(*, cohorts, data, min_cohort_size):
18721877

18731878
for coh, query in cohorts.items():
18741879
try:
1880+
validate_query(query)
18751881
loc_coh = data.eval(query).values
18761882
except (KeyError, NameError, SyntaxError, TypeError, AttributeError) as e:
18771883
raise ValueError(

malariagen_data/anoph/snp_frq.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_check_types,
1616
_pandas_apply,
1717
)
18+
from .safe_query import validate_query
1819
from .snp_data import AnophelesSnpData
1920
from .frq_base import (
2021
_prep_samples_for_cohort_grouping,
@@ -690,6 +691,7 @@ def snp_allele_frequencies_advanced(
690691

691692
# Apply variant query.
692693
if variant_query is not None:
694+
validate_query(variant_query)
693695
loc_variants = np.asarray(df_variants.eval(variant_query))
694696

695697
# Check for no SNPs remaining after applying variant query.
@@ -834,6 +836,7 @@ def aa_allele_frequencies_advanced(
834836

835837
# Apply variant query if given.
836838
if variant_query is not None:
839+
validate_query(variant_query)
837840
loc_variants = df_variants.eval(variant_query).values
838841

839842
# Check for no SNPs remaining after applying variant query.
@@ -923,6 +926,7 @@ def snp_genotype_allele_counts(
923926
df_snps = df_snps.loc[loc_sites]
924927

925928
if snp_query is not None:
929+
validate_query(snp_query)
926930
df_snps = df_snps.query(snp_query)
927931

928932
return df_snps

0 commit comments

Comments
 (0)