diff --git a/malariagen_data/anoph/base.py b/malariagen_data/anoph/base.py index dd09226b6..a5293a539 100644 --- a/malariagen_data/anoph/base.py +++ b/malariagen_data/anoph/base.py @@ -28,6 +28,8 @@ from numpydoc_decorator import doc # type: ignore from tqdm.auto import tqdm as tqdm_auto # type: ignore from tqdm.dask import TqdmCallback # type: ignore + +from .safe_query import validate_query from yaspin import yaspin # type: ignore import xarray as xr @@ -980,10 +982,9 @@ def _filter_sample_dataset( # Determine which samples match the sample query. if sample_query != "": - # Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean. - loc_samples = df_samples.eval( - sample_query, **sample_query_options, engine="python" - ) + # Validate the query to prevent arbitrary code execution (GH-1292). + validate_query(sample_query) + loc_samples = df_samples.eval(sample_query, **sample_query_options) else: loc_samples = pd.Series(True, index=df_samples.index) diff --git a/malariagen_data/anoph/cnv_frq.py b/malariagen_data/anoph/cnv_frq.py index 84a17100c..ee4724e20 100644 --- a/malariagen_data/anoph/cnv_frq.py +++ b/malariagen_data/anoph/cnv_frq.py @@ -15,6 +15,7 @@ _build_cohorts_from_sample_grouping, _add_frequency_ci, ) +from .safe_query import validate_query from ..util import ( _check_types, _pandas_apply, @@ -671,6 +672,7 @@ def _gene_cnv_frequencies_advanced( debug("apply variant query") if variant_query is not None: + validate_query(variant_query) loc_variants = df_variants.eval(variant_query).values # Convert boolean mask to integer indices for NumPy 2.x compatibility variant_indices = np.where(loc_variants)[0] diff --git a/malariagen_data/anoph/frq_base.py b/malariagen_data/anoph/frq_base.py index 86c6ecba4..2d8fc3963 100644 --- a/malariagen_data/anoph/frq_base.py +++ b/malariagen_data/anoph/frq_base.py @@ -147,8 +147,10 @@ def _build_cohorts_from_sample_grouping( period_str = df_cohorts["period"].astype(str) df_cohorts["label"] = area_str + "_" + taxon_clean + "_" + period_str - # Apply minimum cohort size. - df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True) + # Apply minimum cohort size using safe boolean indexing. + df_cohorts = df_cohorts.loc[df_cohorts["size"] >= min_cohort_size].reset_index( + drop=True + ) # Early check for no cohorts. if len(df_cohorts) == 0: diff --git a/malariagen_data/anoph/genome_features.py b/malariagen_data/anoph/genome_features.py index ed4cc4c39..65a1dfc2b 100644 --- a/malariagen_data/anoph/genome_features.py +++ b/malariagen_data/anoph/genome_features.py @@ -117,8 +117,8 @@ def _genome_features_for_contig(self, *, contig: str, attributes: Tuple[str, ... ) df = self._genome_features(attributes=attributes) - # Apply contig query. - df = df.query(f"contig == '{contig}'") + # Apply contig filter using safe boolean indexing. + df = df.loc[df["contig"] == contig] return df def _prep_gff_attributes( @@ -162,9 +162,9 @@ def genome_features( contig=r.contig, attributes=attributes_normed ) if r.end is not None: - df_part = df_part.query(f"start <= {r.end}") + df_part = df_part.loc[df_part["start"] <= r.end] if r.start is not None: - df_part = df_part.query(f"end >= {r.start}") + df_part = df_part.loc[df_part["end"] >= r.start] parts.append(df_part) df = pd.concat(parts, axis=0) return df.sort_values(["contig", "start"]).reset_index(drop=True).copy() @@ -192,8 +192,8 @@ def genome_feature_children( df_gf["Parent"] = df_gf["Parent"].str.split(",") df_gf = df_gf.explode(column="Parent", ignore_index=True) - # Query to find children of the requested parent. - df_children = df_gf.query(f"Parent == '{parent}'") + # Filter to find children of the requested parent using safe indexing. + df_children = df_gf.loc[df_gf["Parent"] == parent] return df_children.copy() @@ -670,7 +670,9 @@ def plot_genes( def _plot_genes_setup_data(self, *, region): attributes = [a for a in self._gff_default_attributes if a != "Parent"] df_genome_features = self.genome_features(region=region, attributes=attributes) - data = df_genome_features.query(f"type == '{self._gff_gene_type}'").copy() + data = df_genome_features.loc[ + df_genome_features["type"] == self._gff_gene_type + ].copy() tooltips = [(a.capitalize(), f"@{a}") for a in attributes] tooltips += [("Location", "@contig:@start{,}-@end{,}")] return data, tooltips diff --git a/malariagen_data/anoph/hap_data.py b/malariagen_data/anoph/hap_data.py index fe91b7d8b..e0894208a 100644 --- a/malariagen_data/anoph/hap_data.py +++ b/malariagen_data/anoph/hap_data.py @@ -6,6 +6,8 @@ import zarr # type: ignore from numpydoc_decorator import doc # type: ignore +from .safe_query import validate_query + from ..util import ( DIM_ALLELE, DIM_PLOIDY, @@ -418,7 +420,8 @@ def haplotypes( df_samples.set_index("sample_id").loc[phased_samples].reset_index() ) - # Apply the query. + # Validate the query to prevent arbitrary code execution (GH-1292). + validate_query(sample_query_prepped) sample_query_options = sample_query_options or {} loc_samples = df_samples_phased.eval( sample_query_prepped, **sample_query_options diff --git a/malariagen_data/anoph/hapclust.py b/malariagen_data/anoph/hapclust.py index 74636de11..6a7000647 100644 --- a/malariagen_data/anoph/hapclust.py +++ b/malariagen_data/anoph/hapclust.py @@ -8,6 +8,7 @@ from ..util import CacheMiss, _check_types, _pdist_abs_hamming, _pandas_apply from ..plotly_dendrogram import _plot_dendrogram, concat_clustering_subplots +from .safe_query import validate_query from . import ( base_params, plotly_params, @@ -623,6 +624,7 @@ def transcript_haplotypes( """ # Get SNP genotype allele counts for the transcript, applying snp_query + validate_query(snp_query) df_eff = ( self.snp_effects( transcript=transcript, diff --git a/malariagen_data/anoph/karyotype.py b/malariagen_data/anoph/karyotype.py index d0eda0d54..fcfa23f14 100644 --- a/malariagen_data/anoph/karyotype.py +++ b/malariagen_data/anoph/karyotype.py @@ -62,7 +62,7 @@ def load_inversion_tags(self, inversion: inversion_param) -> pd.DataFrame: else: with importlib.resources.path(resources, self._inversion_tag_path) as path: df_tag_snps = pd.read_csv(path, sep=",") - return df_tag_snps.query(f"inversion == '{inversion}'").reset_index() + return df_tag_snps.loc[df_tag_snps["inversion"] == inversion].reset_index() @_check_types @doc( diff --git a/malariagen_data/anoph/safe_query.py b/malariagen_data/anoph/safe_query.py new file mode 100644 index 000000000..08601c72b --- /dev/null +++ b/malariagen_data/anoph/safe_query.py @@ -0,0 +1,157 @@ +"""Safe query validation for pandas eval/query expressions. + +This module provides AST-based validation of query strings to prevent +arbitrary code execution via pandas DataFrame.eval() and DataFrame.query(). + +Only a restricted subset of Python expressions is allowed: +- Boolean operators: and, or, not +- Comparison operators: ==, !=, <, <=, >, >=, in, not in, is +- Arithmetic operators: +, -, *, /, //, %, ** +- Unary operators: +, -, ~, not +- Constants: strings, numbers, booleans, None +- Names: must match an allowlist of known column names (if provided) +- Parenthesized expressions + +Forbidden constructs include: +- Function calls (e.g., __import__('os')) +- Attribute access (e.g., os.system) +- Subscript/indexing (e.g., x[0]) +- Comprehensions, lambdas, f-strings, starred expressions +- Any identifier containing double underscores (__) +""" + +import ast +import re +from typing import Optional, Set + +# Pattern matching pandas @variable references in query strings. +# These are not valid Python but are a pandas feature for referencing +# local/global variables via the `local_dict` or `global_dict` kwargs. +_AT_VAR_PATTERN = re.compile(r"@([A-Za-z_][A-Za-z0-9_]*)") + + +# AST node types that are safe in query expressions. +_SAFE_NODE_TYPES = ( + ast.Expression, + ast.BoolOp, + ast.BinOp, + ast.UnaryOp, + ast.Compare, + ast.And, + ast.Or, + ast.Not, + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.FloorDiv, + ast.Mod, + ast.Pow, + ast.USub, + ast.UAdd, + ast.Invert, + ast.Eq, + ast.NotEq, + ast.Lt, + ast.LtE, + ast.Gt, + ast.GtE, + ast.In, + ast.NotIn, + ast.Is, + ast.IsNot, + ast.Constant, + ast.Name, + ast.Load, + ast.Tuple, + ast.List, +) + + +class UnsafeQueryError(ValueError): + """Raised when a query string contains unsafe constructs.""" + + pass + + +def _validate_node(node: ast.AST, allowed_names: Optional[Set[str]] = None) -> None: + """Recursively validate that an AST node contains only safe constructs. + + Parameters + ---------- + node : ast.AST + The AST node to validate. + allowed_names : set of str, optional + If provided, restrict identifier names to this set. + + Raises + ------ + UnsafeQueryError + If the node or any of its children contain unsafe constructs. + """ + if not isinstance(node, _SAFE_NODE_TYPES): + raise UnsafeQueryError( + f"Unsafe expression: {type(node).__name__} nodes are not allowed " + f"in query strings. Only comparisons, boolean logic, and constants " + f"are permitted." + ) + + if isinstance(node, ast.Name): + name = node.id + # Block dunder identifiers. + if "__" in name: + raise UnsafeQueryError( + f"Unsafe expression: identifier '{name}' contains double " + f"underscores and is not allowed in query strings." + ) + # Check against allowlist if provided. + if allowed_names is not None and name not in allowed_names: + # Allow common boolean literals that pandas recognizes. + if name not in {"True", "False", "None"}: + raise UnsafeQueryError( + f"Unknown column name '{name}' in query string. " + f"Allowed column names: {sorted(allowed_names)}" + ) + + # Recurse into child nodes. + for child in ast.iter_child_nodes(node): + _validate_node(child, allowed_names) + + +def validate_query(query: str, allowed_names: Optional[Set[str]] = None) -> None: + """Validate that a query string is safe for use with pandas eval/query. + + Parameters + ---------- + query : str + The query string to validate. + allowed_names : set of str, optional + If provided, restrict identifier names to this set of known column + names. If None, any identifier (except those containing ``__``) is + allowed. + + Raises + ------ + UnsafeQueryError + If the query contains unsafe constructs such as function calls, + attribute access, or dunder identifiers. + """ + if not isinstance(query, str): + raise UnsafeQueryError(f"Query must be a string, got {type(query).__name__}.") + + query = query.strip() + if not query: + raise UnsafeQueryError("Query string must not be empty.") + + # Replace pandas @variable references with plain identifiers so the + # expression can be parsed as valid Python. The replaced names are + # prefixed with ``_at_`` to avoid collisions with real column names + # while remaining dunder-free. + query_for_parse = _AT_VAR_PATTERN.sub(r"_at_\1", query) + + try: + tree = ast.parse(query_for_parse, mode="eval") + except SyntaxError as e: + raise UnsafeQueryError(f"Query string is not a valid expression: {e}") from e + + _validate_node(tree, allowed_names) diff --git a/malariagen_data/anoph/sample_metadata.py b/malariagen_data/anoph/sample_metadata.py index 38630f450..19b8db14f 100644 --- a/malariagen_data/anoph/sample_metadata.py +++ b/malariagen_data/anoph/sample_metadata.py @@ -24,6 +24,8 @@ import plotly.express as px # type: ignore from numpydoc_decorator import doc # type: ignore +from .safe_query import validate_query + from ..util import _check_types from . import base_params, map_params, plotly_params from .base import AnophelesBase @@ -808,10 +810,9 @@ def sample_metadata( # zero-result queries and provide a helpful warning. df_before_query = df_samples - # Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean. - df_samples = df_samples.query( - prepared_sample_query, **sample_query_options, engine="python" - ) + # Validate the query to prevent arbitrary code execution (GH-1292). + validate_query(prepared_sample_query) + df_samples = df_samples.query(prepared_sample_query, **sample_query_options) df_samples = df_samples.reset_index(drop=True) # Warn if query returned zero results on a non-empty dataset. @@ -1197,12 +1198,13 @@ def _prep_sample_selection_cache_params( # Default the sample_query_options to an empty dict. sample_query_options = sample_query_options or {} - # Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean. + # Validate the query to prevent arbitrary code execution (GH-1292). # Get the Pandas Series as a NumPy array of Boolean values. # Note: if `prepared_sample_query` is an internal query, this will select all samples, # since `sample_metadata` should have already applied the internal query. + validate_query(prepared_sample_query) loc_samples = df_samples.eval( - prepared_sample_query, **sample_query_options, engine="python" + prepared_sample_query, **sample_query_options ).values # Convert the sample indices to a list. @@ -1368,6 +1370,7 @@ def _setup_sample_symbol( ) data["symbol"] = "" for key, value in symbol.items(): + validate_query(value) data.loc[data.query(value).index, "symbol"] = key symbol_prepped = "symbol" @@ -1421,6 +1424,7 @@ def _setup_sample_colors_plotly( ) data["color"] = "" for key, value in color.items(): + validate_query(value) data.loc[data.query(value).index, "color"] = key color_prepped = "color" @@ -1654,6 +1658,7 @@ def cohorts( self._cache_cohorts[cache_key] = df_cohorts if query is not None: + validate_query(query) df_cohorts = df_cohorts.query(query) df_cohorts = df_cohorts.reset_index(drop=True) @@ -1872,6 +1877,7 @@ def _locate_cohorts(*, cohorts, data, min_cohort_size): for coh, query in cohorts.items(): try: + validate_query(query) loc_coh = data.eval(query).values except (KeyError, NameError, SyntaxError, TypeError, AttributeError) as e: raise ValueError( diff --git a/malariagen_data/anoph/snp_frq.py b/malariagen_data/anoph/snp_frq.py index 40f3f2000..0d4e98840 100644 --- a/malariagen_data/anoph/snp_frq.py +++ b/malariagen_data/anoph/snp_frq.py @@ -15,6 +15,7 @@ _check_types, _pandas_apply, ) +from .safe_query import validate_query from .snp_data import AnophelesSnpData from .frq_base import ( _prep_samples_for_cohort_grouping, @@ -690,6 +691,7 @@ def snp_allele_frequencies_advanced( # Apply variant query. if variant_query is not None: + validate_query(variant_query) loc_variants = np.asarray(df_variants.eval(variant_query)) # Check for no SNPs remaining after applying variant query. @@ -834,6 +836,7 @@ def aa_allele_frequencies_advanced( # Apply variant query if given. if variant_query is not None: + validate_query(variant_query) loc_variants = df_variants.eval(variant_query).values # Check for no SNPs remaining after applying variant query. @@ -923,6 +926,7 @@ def snp_genotype_allele_counts( df_snps = df_snps.loc[loc_sites] if snp_query is not None: + validate_query(snp_query) df_snps = df_snps.query(snp_query) return df_snps diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index f7506c79e..8342dbb88 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -11,6 +11,7 @@ import plotly.graph_objects as go # type: ignore from numpydoc_decorator import doc # type: ignore +from .anoph.safe_query import validate_query from .anoph import ( aim_params, @@ -1297,7 +1298,8 @@ def plot_haplotype_network( # Apply each query in the mapping to create the _partition column for label, query in color.items(): - # Apply the query and assign the label to matching rows + # Validate and apply the query to matching rows + validate_query(query) mask = df_haps.eval(query) df_haps.loc[mask, "_partition"] = label diff --git a/malariagen_data/util.py b/malariagen_data/util.py index aa8b2d6f0..ca7b38677 100644 --- a/malariagen_data/util.py +++ b/malariagen_data/util.py @@ -644,7 +644,7 @@ def _prep_geneset_attributes_arg(attributes): def _handle_region_feature(resource, region): if hasattr(resource, "genome_features"): gene_annotation = resource.genome_features(attributes=["ID"]) - results = gene_annotation.query(f"ID == '{region}'") + results = gene_annotation.loc[gene_annotation["ID"] == region] if not results.empty: # the region is a feature ID feature = results.squeeze() diff --git a/tests/anoph/test_safe_query.py b/tests/anoph/test_safe_query.py new file mode 100644 index 000000000..4ee1d923f --- /dev/null +++ b/tests/anoph/test_safe_query.py @@ -0,0 +1,200 @@ +"""Tests for the safe_query module (GH-1292). + +Ensures that the AST-based query validator correctly accepts safe expressions +and rejects malicious ones that could lead to arbitrary code execution via +pandas DataFrame.eval() / DataFrame.query(). +""" + +import pytest +from typeguard import TypeCheckError + +from malariagen_data.anoph.safe_query import UnsafeQueryError, validate_query + + +class TestValidateQueryAcceptsSafe: + """Ensure legitimate pandas query expressions are accepted.""" + + def test_simple_equality(self): + validate_query("country == 'Ghana'") + + def test_numeric_comparison(self): + validate_query("size >= 10") + + def test_boolean_and(self): + validate_query("country == 'Ghana' and year == 2020") + + def test_boolean_or(self): + validate_query("country == 'Ghana' or country == 'Mali'") + + def test_not_operator(self): + validate_query("not country == 'Ghana'") + + def test_parenthesized_expression(self): + validate_query("(country == 'Ghana') and (year > 2015)") + + def test_in_operator_with_tuple(self): + validate_query("country in ('Ghana', 'Mali', 'Kenya')") + + def test_not_in_operator(self): + validate_query("country not in ('Ghana', 'Mali')") + + def test_complex_boolean_chain(self): + validate_query("country == 'Ghana' and year >= 2015 and taxon == 'gambiae'") + + def test_numeric_arithmetic(self): + validate_query("size + 1 > 10") + + def test_is_comparison(self): + validate_query("value is None") + + def test_is_not_comparison(self): + validate_query("value is not None") + + def test_boolean_literal_true(self): + validate_query("is_surveillance == True") + + def test_boolean_literal_false(self): + validate_query("is_surveillance == False") + + def test_with_allowed_names(self): + validate_query( + "country == 'Ghana'", + allowed_names={"country", "year", "taxon"}, + ) + + def test_inequality_operators(self): + validate_query("year != 2020") + validate_query("size < 100") + validate_query("size <= 100") + validate_query("size > 0") + validate_query("size >= 1") + + def test_unary_minus(self): + validate_query("value > -1") + + def test_list_literal_in(self): + validate_query("country in ['Ghana', 'Mali']") + + def test_whitespace_handling(self): + validate_query(" country == 'Ghana' ") + + def test_at_variable_reference(self): + """Pandas @var syntax for referencing local variables.""" + validate_query("sex_call in @sex_call_list") + + def test_at_variable_in_compound(self): + validate_query("taxon in @taxon_list and year > 2015") + + def test_at_variable_equality(self): + validate_query("country == @target_country") + + +class TestValidateQueryRejectsMalicious: + """Ensure that code injection attempts are blocked.""" + + def test_import_call(self): + with pytest.raises(UnsafeQueryError): + validate_query("__import__('os').system('echo PWNED')") + + def test_import_in_compound_expression(self): + with pytest.raises(UnsafeQueryError): + validate_query( + "__import__('os').system('echo PWNED') or country == 'Ghana'" + ) + + def test_function_call(self): + with pytest.raises(UnsafeQueryError): + validate_query("len(country)") + + def test_attribute_access(self): + with pytest.raises(UnsafeQueryError): + validate_query("country.upper()") + + def test_nested_attribute_access(self): + with pytest.raises(UnsafeQueryError): + validate_query("os.system('id')") + + def test_subscript(self): + with pytest.raises(UnsafeQueryError): + validate_query("country[0]") + + def test_lambda(self): + with pytest.raises(UnsafeQueryError): + validate_query("lambda: 1") + + def test_list_comprehension(self): + with pytest.raises(UnsafeQueryError): + validate_query("[x for x in range(10)]") + + def test_dunder_identifier(self): + with pytest.raises(UnsafeQueryError): + validate_query("__class__") + + def test_dunder_in_name(self): + with pytest.raises(UnsafeQueryError): + validate_query("__builtins__.__import__") + + def test_exec_call(self): + with pytest.raises(UnsafeQueryError): + validate_query("exec('import os')") + + def test_eval_call(self): + with pytest.raises(UnsafeQueryError): + validate_query("eval('1+1')") + + def test_dict_literal(self): + with pytest.raises(UnsafeQueryError): + validate_query("{'key': 'value'}") + + def test_generator_expression(self): + with pytest.raises(UnsafeQueryError): + validate_query("sum(x for x in range(10))") + + def test_starred_expression(self): + with pytest.raises(UnsafeQueryError): + validate_query("*args") + + def test_fstring_attempt(self): + with pytest.raises(UnsafeQueryError): + validate_query("f'{__import__(\"os\")}'") + + def test_walrus_operator(self): + with pytest.raises(UnsafeQueryError): + validate_query("(x := 1)") + + def test_unknown_column_with_allowlist(self): + with pytest.raises(UnsafeQueryError, match="Unknown column name"): + validate_query( + "evil_col == 'value'", + allowed_names={"country", "year"}, + ) + + +class TestValidateQueryEdgeCases: + """Test edge cases and error handling.""" + + def test_empty_string(self): + with pytest.raises(UnsafeQueryError, match="must not be empty"): + validate_query("") + + def test_whitespace_only(self): + with pytest.raises(UnsafeQueryError, match="must not be empty"): + validate_query(" ") + + def test_non_string_input(self): + with pytest.raises((UnsafeQueryError, TypeError, TypeCheckError)): + validate_query(123) + + def test_syntax_error(self): + with pytest.raises(UnsafeQueryError, match="not a valid expression"): + validate_query("country ==") + + def test_multiple_statements(self): + # Multiple statements can't be parsed in eval mode + with pytest.raises(UnsafeQueryError): + validate_query("x = 1; y = 2") + + def test_quote_breaking_attempt(self): + """Ensure quote-breaking in string literals doesn't bypass validation.""" + with pytest.raises(UnsafeQueryError): + validate_query("contig == 'X' or __import__('os').system('id')")