Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions malariagen_data/anoph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from numpydoc_decorator import doc # type: ignore
from tqdm.auto import tqdm as tqdm_auto # type: ignore
from tqdm.dask import TqdmCallback # type: ignore

from .safe_query import validate_query
from yaspin import yaspin # type: ignore
import xarray as xr

Expand Down Expand Up @@ -980,10 +982,9 @@ def _filter_sample_dataset(

# Determine which samples match the sample query.
if sample_query != "":
# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
loc_samples = df_samples.eval(
sample_query, **sample_query_options, engine="python"
)
# Validate the query to prevent arbitrary code execution (GH-1292).
validate_query(sample_query)
loc_samples = df_samples.eval(sample_query, **sample_query_options)
else:
loc_samples = pd.Series(True, index=df_samples.index)

Expand Down
2 changes: 2 additions & 0 deletions malariagen_data/anoph/cnv_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_build_cohorts_from_sample_grouping,
_add_frequency_ci,
)
from .safe_query import validate_query
from ..util import (
_check_types,
_pandas_apply,
Expand Down Expand Up @@ -671,6 +672,7 @@ def _gene_cnv_frequencies_advanced(

debug("apply variant query")
if variant_query is not None:
validate_query(variant_query)
loc_variants = df_variants.eval(variant_query).values
# Convert boolean mask to integer indices for NumPy 2.x compatibility
variant_indices = np.where(loc_variants)[0]
Expand Down
6 changes: 4 additions & 2 deletions malariagen_data/anoph/frq_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,10 @@ def _build_cohorts_from_sample_grouping(
period_str = df_cohorts["period"].astype(str)
df_cohorts["label"] = area_str + "_" + taxon_clean + "_" + period_str

# Apply minimum cohort size.
df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True)
# Apply minimum cohort size using safe boolean indexing.
df_cohorts = df_cohorts.loc[df_cohorts["size"] >= min_cohort_size].reset_index(
drop=True
)

# Early check for no cohorts.
if len(df_cohorts) == 0:
Expand Down
16 changes: 9 additions & 7 deletions malariagen_data/anoph/genome_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ def _genome_features_for_contig(self, *, contig: str, attributes: Tuple[str, ...
)
df = self._genome_features(attributes=attributes)

# Apply contig query.
df = df.query(f"contig == '{contig}'")
# Apply contig filter using safe boolean indexing.
df = df.loc[df["contig"] == contig]
return df

def _prep_gff_attributes(
Expand Down Expand Up @@ -162,9 +162,9 @@ def genome_features(
contig=r.contig, attributes=attributes_normed
)
if r.end is not None:
df_part = df_part.query(f"start <= {r.end}")
df_part = df_part.loc[df_part["start"] <= r.end]
if r.start is not None:
df_part = df_part.query(f"end >= {r.start}")
df_part = df_part.loc[df_part["end"] >= r.start]
parts.append(df_part)
df = pd.concat(parts, axis=0)
return df.sort_values(["contig", "start"]).reset_index(drop=True).copy()
Expand Down Expand Up @@ -192,8 +192,8 @@ def genome_feature_children(
df_gf["Parent"] = df_gf["Parent"].str.split(",")
df_gf = df_gf.explode(column="Parent", ignore_index=True)

# Query to find children of the requested parent.
df_children = df_gf.query(f"Parent == '{parent}'")
# Filter to find children of the requested parent using safe indexing.
df_children = df_gf.loc[df_gf["Parent"] == parent]

return df_children.copy()

Expand Down Expand Up @@ -670,7 +670,9 @@ def plot_genes(
def _plot_genes_setup_data(self, *, region):
attributes = [a for a in self._gff_default_attributes if a != "Parent"]
df_genome_features = self.genome_features(region=region, attributes=attributes)
data = df_genome_features.query(f"type == '{self._gff_gene_type}'").copy()
data = df_genome_features.loc[
df_genome_features["type"] == self._gff_gene_type
].copy()
tooltips = [(a.capitalize(), f"@{a}") for a in attributes]
tooltips += [("Location", "@contig:@start{,}-@end{,}")]
return data, tooltips
Expand Down
5 changes: 4 additions & 1 deletion malariagen_data/anoph/hap_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import zarr # type: ignore
from numpydoc_decorator import doc # type: ignore

from .safe_query import validate_query

from ..util import (
DIM_ALLELE,
DIM_PLOIDY,
Expand Down Expand Up @@ -418,7 +420,8 @@ def haplotypes(
df_samples.set_index("sample_id").loc[phased_samples].reset_index()
)

# Apply the query.
# Validate the query to prevent arbitrary code execution (GH-1292).
validate_query(sample_query_prepped)
sample_query_options = sample_query_options or {}
loc_samples = df_samples_phased.eval(
sample_query_prepped, **sample_query_options
Expand Down
2 changes: 2 additions & 0 deletions malariagen_data/anoph/hapclust.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ..util import CacheMiss, _check_types, _pdist_abs_hamming, _pandas_apply
from ..plotly_dendrogram import _plot_dendrogram, concat_clustering_subplots
from .safe_query import validate_query
from . import (
base_params,
plotly_params,
Expand Down Expand Up @@ -623,6 +624,7 @@ def transcript_haplotypes(
"""

# Get SNP genotype allele counts for the transcript, applying snp_query
validate_query(snp_query)
df_eff = (
self.snp_effects(
transcript=transcript,
Expand Down
2 changes: 1 addition & 1 deletion malariagen_data/anoph/karyotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def load_inversion_tags(self, inversion: inversion_param) -> pd.DataFrame:
else:
with importlib.resources.path(resources, self._inversion_tag_path) as path:
df_tag_snps = pd.read_csv(path, sep=",")
return df_tag_snps.query(f"inversion == '{inversion}'").reset_index()
return df_tag_snps.loc[df_tag_snps["inversion"] == inversion].reset_index()

@_check_types
@doc(
Expand Down
145 changes: 145 additions & 0 deletions malariagen_data/anoph/safe_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Safe query validation for pandas eval/query expressions.

This module provides AST-based validation of query strings to prevent
arbitrary code execution via pandas DataFrame.eval() and DataFrame.query().

Only a restricted subset of Python expressions is allowed:
- Boolean operators: and, or, not
- Comparison operators: ==, !=, <, <=, >, >=, in, not in, is
- Arithmetic operators: +, -, *, /, //, %, **
- Unary operators: +, -, ~, not
- Constants: strings, numbers, booleans, None
- Names: must match an allowlist of known column names (if provided)
- Parenthesized expressions

Forbidden constructs include:
- Function calls (e.g., __import__('os'))
- Attribute access (e.g., os.system)
- Subscript/indexing (e.g., x[0])
- Comprehensions, lambdas, f-strings, starred expressions
- Any identifier containing double underscores (__)
"""

import ast
from typing import Optional, Set


# AST node types that are safe in query expressions.
_SAFE_NODE_TYPES = (
ast.Expression,
ast.BoolOp,
ast.BinOp,
ast.UnaryOp,
ast.Compare,
ast.And,
ast.Or,
ast.Not,
ast.Add,
ast.Sub,
ast.Mult,
ast.Div,
ast.FloorDiv,
ast.Mod,
ast.Pow,
ast.USub,
ast.UAdd,
ast.Invert,
ast.Eq,
ast.NotEq,
ast.Lt,
ast.LtE,
ast.Gt,
ast.GtE,
ast.In,
ast.NotIn,
ast.Is,
ast.IsNot,
ast.Constant,
ast.Name,
ast.Load,
ast.Tuple,
ast.List,
)


class UnsafeQueryError(ValueError):
"""Raised when a query string contains unsafe constructs."""

pass


def _validate_node(node: ast.AST, allowed_names: Optional[Set[str]] = None) -> None:
"""Recursively validate that an AST node contains only safe constructs.

Parameters
----------
node : ast.AST
The AST node to validate.
allowed_names : set of str, optional
If provided, restrict identifier names to this set.

Raises
------
UnsafeQueryError
If the node or any of its children contain unsafe constructs.
"""
if not isinstance(node, _SAFE_NODE_TYPES):
raise UnsafeQueryError(
f"Unsafe expression: {type(node).__name__} nodes are not allowed "
f"in query strings. Only comparisons, boolean logic, and constants "
f"are permitted."
Comment on lines +95 to +96
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generic error message for disallowed AST nodes says "Only comparisons, boolean logic, and constants are permitted." but this validator also explicitly allows arithmetic (ast.BinOp with +, -, *, etc.) and unary operators. Updating the message to reflect the actual allowlist would avoid confusing users when they hit a validation error.

Suggested change
f"in query strings. Only comparisons, boolean logic, and constants "
f"are permitted."
f"in query strings. Only comparisons, boolean logic, arithmetic, "
f"unary operations, names, and constants are permitted."

Copilot uses AI. Check for mistakes.
)

if isinstance(node, ast.Name):
name = node.id
# Block dunder identifiers.
if "__" in name:
raise UnsafeQueryError(
f"Unsafe expression: identifier '{name}' contains double "
f"underscores and is not allowed in query strings."
)
# Check against allowlist if provided.
if allowed_names is not None and name not in allowed_names:
# Allow common boolean literals that pandas recognizes.
if name not in {"True", "False", "None"}:
raise UnsafeQueryError(
f"Unknown column name '{name}' in query string. "
f"Allowed column names: {sorted(allowed_names)}"
)

# Recurse into child nodes.
for child in ast.iter_child_nodes(node):
_validate_node(child, allowed_names)


def validate_query(query: str, allowed_names: Optional[Set[str]] = None) -> None:
"""Validate that a query string is safe for use with pandas eval/query.

Parameters
----------
query : str
The query string to validate.
allowed_names : set of str, optional
If provided, restrict identifier names to this set of known column
names. If None, any identifier (except those containing ``__``) is
allowed.

Raises
------
UnsafeQueryError
If the query contains unsafe constructs such as function calls,
attribute access, or dunder identifiers.
"""
if not isinstance(query, str):
raise UnsafeQueryError(f"Query must be a string, got {type(query).__name__}.")

query = query.strip()
if not query:
raise UnsafeQueryError("Query string must not be empty.")

try:
tree = ast.parse(query, mode="eval")
except SyntaxError as e:
raise UnsafeQueryError(f"Query string is not a valid expression: {e}") from e

_validate_node(tree, allowed_names)
18 changes: 12 additions & 6 deletions malariagen_data/anoph/sample_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import plotly.express as px # type: ignore
from numpydoc_decorator import doc # type: ignore

from .safe_query import validate_query

from ..util import _check_types
from . import base_params, map_params, plotly_params
from .base import AnophelesBase
Expand Down Expand Up @@ -808,10 +810,9 @@ def sample_metadata(
# zero-result queries and provide a helpful warning.
df_before_query = df_samples

# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
df_samples = df_samples.query(
prepared_sample_query, **sample_query_options, engine="python"
)
# Validate the query to prevent arbitrary code execution (GH-1292).
validate_query(prepared_sample_query)
df_samples = df_samples.query(prepared_sample_query, **sample_query_options)
df_samples = df_samples.reset_index(drop=True)

# Warn if query returned zero results on a non-empty dataset.
Expand Down Expand Up @@ -1197,12 +1198,13 @@ def _prep_sample_selection_cache_params(
# Default the sample_query_options to an empty dict.
sample_query_options = sample_query_options or {}

# Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean.
# Validate the query to prevent arbitrary code execution (GH-1292).
# Get the Pandas Series as a NumPy array of Boolean values.
# Note: if `prepared_sample_query` is an internal query, this will select all samples,
# since `sample_metadata` should have already applied the internal query.
validate_query(prepared_sample_query)
loc_samples = df_samples.eval(
prepared_sample_query, **sample_query_options, engine="python"
prepared_sample_query, **sample_query_options
).values

# Convert the sample indices to a list.
Expand Down Expand Up @@ -1368,6 +1370,7 @@ def _setup_sample_symbol(
)
data["symbol"] = ""
for key, value in symbol.items():
validate_query(value)
data.loc[data.query(value).index, "symbol"] = key
symbol_prepped = "symbol"

Expand Down Expand Up @@ -1421,6 +1424,7 @@ def _setup_sample_colors_plotly(
)
data["color"] = ""
for key, value in color.items():
validate_query(value)
data.loc[data.query(value).index, "color"] = key
color_prepped = "color"

Expand Down Expand Up @@ -1654,6 +1658,7 @@ def cohorts(
self._cache_cohorts[cache_key] = df_cohorts

if query is not None:
validate_query(query)
df_cohorts = df_cohorts.query(query)
df_cohorts = df_cohorts.reset_index(drop=True)

Expand Down Expand Up @@ -1872,6 +1877,7 @@ def _locate_cohorts(*, cohorts, data, min_cohort_size):

for coh, query in cohorts.items():
try:
validate_query(query)
loc_coh = data.eval(query).values
except (KeyError, NameError, SyntaxError, TypeError, AttributeError) as e:
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validate_query(query) can raise UnsafeQueryError (a ValueError), but the surrounding try/except only catches KeyError, NameError, SyntaxError, TypeError, AttributeError. As a result, unsafe queries will bypass this wrapper and won’t be re-raised with the cohort context (Invalid query for cohort ...). Consider catching UnsafeQueryError here (or broadening to include ValueError) so users still get the cohort-specific error message.

Suggested change
except (KeyError, NameError, SyntaxError, TypeError, AttributeError) as e:
except (
KeyError,
NameError,
SyntaxError,
TypeError,
AttributeError,
ValueError,
) as e:

Copilot uses AI. Check for mistakes.
raise ValueError(
Expand Down
4 changes: 4 additions & 0 deletions malariagen_data/anoph/snp_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_check_types,
_pandas_apply,
)
from .safe_query import validate_query
from .snp_data import AnophelesSnpData
from .frq_base import (
_prep_samples_for_cohort_grouping,
Expand Down Expand Up @@ -690,6 +691,7 @@ def snp_allele_frequencies_advanced(

# Apply variant query.
if variant_query is not None:
validate_query(variant_query)
loc_variants = np.asarray(df_variants.eval(variant_query))

# Check for no SNPs remaining after applying variant query.
Expand Down Expand Up @@ -834,6 +836,7 @@ def aa_allele_frequencies_advanced(

# Apply variant query if given.
if variant_query is not None:
validate_query(variant_query)
loc_variants = df_variants.eval(variant_query).values

# Check for no SNPs remaining after applying variant query.
Expand Down Expand Up @@ -923,6 +926,7 @@ def snp_genotype_allele_counts(
df_snps = df_snps.loc[loc_sites]

if snp_query is not None:
validate_query(snp_query)
df_snps = df_snps.query(snp_query)

return df_snps
Expand Down
Loading
Loading