-
Notifications
You must be signed in to change notification settings - Fork 178
Fix: Sanitize user input passed to DataFrame.eval()/query() to prevent code injection #1293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
e6ef1cf
6fafe9b
90c58aa
4093e67
fd1960b
4d8eb65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,145 @@ | ||
| """Safe query validation for pandas eval/query expressions. | ||
|
|
||
| This module provides AST-based validation of query strings to prevent | ||
| arbitrary code execution via pandas DataFrame.eval() and DataFrame.query(). | ||
|
|
||
| Only a restricted subset of Python expressions is allowed: | ||
| - Boolean operators: and, or, not | ||
| - Comparison operators: ==, !=, <, <=, >, >=, in, not in, is | ||
| - Arithmetic operators: +, -, *, /, //, %, ** | ||
| - Unary operators: +, -, ~, not | ||
| - Constants: strings, numbers, booleans, None | ||
| - Names: must match an allowlist of known column names (if provided) | ||
| - Parenthesized expressions | ||
|
|
||
| Forbidden constructs include: | ||
| - Function calls (e.g., __import__('os')) | ||
| - Attribute access (e.g., os.system) | ||
| - Subscript/indexing (e.g., x[0]) | ||
| - Comprehensions, lambdas, f-strings, starred expressions | ||
| - Any identifier containing double underscores (__) | ||
| """ | ||
|
|
||
| import ast | ||
| from typing import Optional, Set | ||
|
|
||
|
|
||
| # AST node types that are safe in query expressions. | ||
| _SAFE_NODE_TYPES = ( | ||
| ast.Expression, | ||
| ast.BoolOp, | ||
| ast.BinOp, | ||
| ast.UnaryOp, | ||
| ast.Compare, | ||
| ast.And, | ||
| ast.Or, | ||
| ast.Not, | ||
| ast.Add, | ||
| ast.Sub, | ||
| ast.Mult, | ||
| ast.Div, | ||
| ast.FloorDiv, | ||
| ast.Mod, | ||
| ast.Pow, | ||
| ast.USub, | ||
| ast.UAdd, | ||
| ast.Invert, | ||
| ast.Eq, | ||
| ast.NotEq, | ||
| ast.Lt, | ||
| ast.LtE, | ||
| ast.Gt, | ||
| ast.GtE, | ||
| ast.In, | ||
| ast.NotIn, | ||
| ast.Is, | ||
| ast.IsNot, | ||
| ast.Constant, | ||
| ast.Name, | ||
| ast.Load, | ||
| ast.Tuple, | ||
| ast.List, | ||
| ) | ||
|
|
||
|
|
||
| class UnsafeQueryError(ValueError): | ||
| """Raised when a query string contains unsafe constructs.""" | ||
|
|
||
| pass | ||
|
|
||
|
|
||
| def _validate_node(node: ast.AST, allowed_names: Optional[Set[str]] = None) -> None: | ||
| """Recursively validate that an AST node contains only safe constructs. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| node : ast.AST | ||
| The AST node to validate. | ||
| allowed_names : set of str, optional | ||
| If provided, restrict identifier names to this set. | ||
|
|
||
| Raises | ||
| ------ | ||
| UnsafeQueryError | ||
| If the node or any of its children contain unsafe constructs. | ||
| """ | ||
| if not isinstance(node, _SAFE_NODE_TYPES): | ||
| raise UnsafeQueryError( | ||
| f"Unsafe expression: {type(node).__name__} nodes are not allowed " | ||
| f"in query strings. Only comparisons, boolean logic, and constants " | ||
| f"are permitted." | ||
| ) | ||
|
|
||
| if isinstance(node, ast.Name): | ||
| name = node.id | ||
| # Block dunder identifiers. | ||
| if "__" in name: | ||
| raise UnsafeQueryError( | ||
| f"Unsafe expression: identifier '{name}' contains double " | ||
| f"underscores and is not allowed in query strings." | ||
| ) | ||
| # Check against allowlist if provided. | ||
| if allowed_names is not None and name not in allowed_names: | ||
| # Allow common boolean literals that pandas recognizes. | ||
| if name not in {"True", "False", "None"}: | ||
| raise UnsafeQueryError( | ||
| f"Unknown column name '{name}' in query string. " | ||
| f"Allowed column names: {sorted(allowed_names)}" | ||
| ) | ||
|
|
||
| # Recurse into child nodes. | ||
| for child in ast.iter_child_nodes(node): | ||
| _validate_node(child, allowed_names) | ||
|
|
||
|
|
||
| def validate_query(query: str, allowed_names: Optional[Set[str]] = None) -> None: | ||
| """Validate that a query string is safe for use with pandas eval/query. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| query : str | ||
| The query string to validate. | ||
| allowed_names : set of str, optional | ||
| If provided, restrict identifier names to this set of known column | ||
| names. If None, any identifier (except those containing ``__``) is | ||
| allowed. | ||
|
|
||
| Raises | ||
| ------ | ||
| UnsafeQueryError | ||
| If the query contains unsafe constructs such as function calls, | ||
| attribute access, or dunder identifiers. | ||
| """ | ||
| if not isinstance(query, str): | ||
| raise UnsafeQueryError(f"Query must be a string, got {type(query).__name__}.") | ||
|
|
||
| query = query.strip() | ||
| if not query: | ||
| raise UnsafeQueryError("Query string must not be empty.") | ||
|
|
||
| try: | ||
| tree = ast.parse(query, mode="eval") | ||
| except SyntaxError as e: | ||
| raise UnsafeQueryError(f"Query string is not a valid expression: {e}") from e | ||
|
|
||
| _validate_node(tree, allowed_names) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -24,6 +24,8 @@ | |||||||||||||||||||
| import plotly.express as px # type: ignore | ||||||||||||||||||||
| from numpydoc_decorator import doc # type: ignore | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from .safe_query import validate_query | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from ..util import _check_types | ||||||||||||||||||||
| from . import base_params, map_params, plotly_params | ||||||||||||||||||||
| from .base import AnophelesBase | ||||||||||||||||||||
|
|
@@ -808,10 +810,9 @@ def sample_metadata( | |||||||||||||||||||
| # zero-result queries and provide a helpful warning. | ||||||||||||||||||||
| df_before_query = df_samples | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean. | ||||||||||||||||||||
| df_samples = df_samples.query( | ||||||||||||||||||||
| prepared_sample_query, **sample_query_options, engine="python" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| # Validate the query to prevent arbitrary code execution (GH-1292). | ||||||||||||||||||||
| validate_query(prepared_sample_query) | ||||||||||||||||||||
| df_samples = df_samples.query(prepared_sample_query, **sample_query_options) | ||||||||||||||||||||
| df_samples = df_samples.reset_index(drop=True) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Warn if query returned zero results on a non-empty dataset. | ||||||||||||||||||||
|
|
@@ -1197,12 +1198,13 @@ def _prep_sample_selection_cache_params( | |||||||||||||||||||
| # Default the sample_query_options to an empty dict. | ||||||||||||||||||||
| sample_query_options = sample_query_options or {} | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Use the python engine in order to support extension array dtypes, e.g. Float64, Int64, boolean. | ||||||||||||||||||||
| # Validate the query to prevent arbitrary code execution (GH-1292). | ||||||||||||||||||||
| # Get the Pandas Series as a NumPy array of Boolean values. | ||||||||||||||||||||
| # Note: if `prepared_sample_query` is an internal query, this will select all samples, | ||||||||||||||||||||
| # since `sample_metadata` should have already applied the internal query. | ||||||||||||||||||||
| validate_query(prepared_sample_query) | ||||||||||||||||||||
| loc_samples = df_samples.eval( | ||||||||||||||||||||
| prepared_sample_query, **sample_query_options, engine="python" | ||||||||||||||||||||
| prepared_sample_query, **sample_query_options | ||||||||||||||||||||
| ).values | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Convert the sample indices to a list. | ||||||||||||||||||||
|
|
@@ -1368,6 +1370,7 @@ def _setup_sample_symbol( | |||||||||||||||||||
| ) | ||||||||||||||||||||
| data["symbol"] = "" | ||||||||||||||||||||
| for key, value in symbol.items(): | ||||||||||||||||||||
| validate_query(value) | ||||||||||||||||||||
| data.loc[data.query(value).index, "symbol"] = key | ||||||||||||||||||||
| symbol_prepped = "symbol" | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -1421,6 +1424,7 @@ def _setup_sample_colors_plotly( | |||||||||||||||||||
| ) | ||||||||||||||||||||
| data["color"] = "" | ||||||||||||||||||||
| for key, value in color.items(): | ||||||||||||||||||||
| validate_query(value) | ||||||||||||||||||||
| data.loc[data.query(value).index, "color"] = key | ||||||||||||||||||||
| color_prepped = "color" | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -1654,6 +1658,7 @@ def cohorts( | |||||||||||||||||||
| self._cache_cohorts[cache_key] = df_cohorts | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if query is not None: | ||||||||||||||||||||
| validate_query(query) | ||||||||||||||||||||
| df_cohorts = df_cohorts.query(query) | ||||||||||||||||||||
| df_cohorts = df_cohorts.reset_index(drop=True) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -1872,6 +1877,7 @@ def _locate_cohorts(*, cohorts, data, min_cohort_size): | |||||||||||||||||||
|
|
||||||||||||||||||||
| for coh, query in cohorts.items(): | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| validate_query(query) | ||||||||||||||||||||
| loc_coh = data.eval(query).values | ||||||||||||||||||||
| except (KeyError, NameError, SyntaxError, TypeError, AttributeError) as e: | ||||||||||||||||||||
|
||||||||||||||||||||
| except (KeyError, NameError, SyntaxError, TypeError, AttributeError) as e: | |
| except ( | |
| KeyError, | |
| NameError, | |
| SyntaxError, | |
| TypeError, | |
| AttributeError, | |
| ValueError, | |
| ) as e: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The generic error message for disallowed AST nodes says "Only comparisons, boolean logic, and constants are permitted." but this validator also explicitly allows arithmetic (
ast.BinOpwith+,-,*, etc.) and unary operators. Updating the message to reflect the actual allowlist would avoid confusing users when they hit a validation error.