Skip to content

Commit 66c50dc

Browse files
Expose DoS-protection knobs via URL + pin async dialect_description
Post-review follow-up. Adds two entries to the URL query allowlist: - max_continuation_frames (int) — per-query frame cap - trust_server_heartbeat (bool) — opt-in server-heartbeat trust bool parsing uses a URL-friendly parser because bool("False") is truthy. Accepts 1/true/yes/on as true; everything else as false. The URL value converter now handles the str | tuple[str, ...] shape that SQLAlchemy uses for multi-valued query params (takes the last occurrence). Also adds the async variant of the dialect_description pin test (correctness review flagged the sync-only coverage). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 66a494b commit 66c50dc

3 files changed

Lines changed: 58 additions & 3 deletions

File tree

src/sqlalchemydqlite/base.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import datetime
55
import warnings
6+
from collections.abc import Callable
67
from typing import Any
78

89
from sqlalchemy import types as sqltypes
@@ -111,7 +112,14 @@ def import_dbapi(cls) -> Any:
111112

112113
# Whitelist of URL query parameters we forward to the DBAPI connect
113114
# call. Unknown keys raise ``ArgumentError`` so typos surface.
114-
_URL_QUERY_ALLOWED: dict[str, type] = {"timeout": float, "max_total_rows": int}
115+
# ``trust_server_heartbeat`` uses a URL-friendly bool parser because
116+
# bool("False") evaluates truthy (non-empty string).
117+
_URL_QUERY_ALLOWED: dict[str, Callable[[str], Any]] = {
118+
"timeout": float,
119+
"max_total_rows": int,
120+
"max_continuation_frames": int,
121+
"trust_server_heartbeat": lambda s: s.strip().lower() in ("1", "true", "yes", "on"),
122+
}
115123

116124
def create_connect_args(self, url: URL) -> tuple[list[Any], dict[str, Any]]:
117125
"""Create connection arguments from URL.
@@ -137,11 +145,15 @@ def create_connect_args(self, url: URL) -> tuple[list[Any], dict[str, Any]]:
137145
f"Allowed: {sorted(self._URL_QUERY_ALLOWED)}"
138146
)
139147
converter = self._URL_QUERY_ALLOWED[key]
148+
# URL query values can be str or tuple[str, ...] (when a key
149+
# appears multiple times). Take the last occurrence.
150+
raw_str = raw[-1] if isinstance(raw, tuple) else raw
140151
try:
141-
kwargs[key] = converter(raw)
152+
kwargs[key] = converter(raw_str)
142153
except (TypeError, ValueError) as e:
143154
raise ArgumentError(
144-
f"Cannot convert URL query {key}={raw!r} to {converter.__name__}: {e}"
155+
f"Cannot convert URL query {key}={raw!r} to "
156+
f"{getattr(converter, '__name__', 'expected type')}: {e}"
145157
) from e
146158

147159
return [], kwargs

tests/test_dialect.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def test_dialect_description(self) -> None:
5959
# (ISSUE-89).
6060
assert DqliteDialect().dialect_description == "dqlite+dqlitedbapi"
6161

62+
def test_async_dialect_description(self) -> None:
63+
# Mirror ISSUE-89 pin for the async dialect; review agent flagged
64+
# that the sync test alone could mask an async-side drift.
65+
assert DqliteDialect_aio().dialect_description == "dqlite+dqlitedbapi_aio"
66+
6267

6368
class TestDqliteDialectAio:
6469
def test_dialect_name(self) -> None:

tests/test_dialect_dialect_config.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,44 @@ def test_timeout_and_max_total_rows_together(self) -> None:
8888
assert kwargs["timeout"] == 3.5
8989
assert kwargs["max_total_rows"] == 250
9090

91+
def test_max_continuation_frames_forwarded(self) -> None:
92+
"""ISSUE-98 URL plumbing — post-review follow-up."""
93+
dialect = DqliteDialect()
94+
url = make_url("dqlite://host:19001/db?max_continuation_frames=500")
95+
_, kwargs = dialect.create_connect_args(url)
96+
assert kwargs["max_continuation_frames"] == 500
97+
98+
def test_max_continuation_frames_rejects_non_int(self) -> None:
99+
dialect = DqliteDialect()
100+
url = make_url("dqlite://host:19001/db?max_continuation_frames=nope")
101+
with pytest.raises(ArgumentError, match="int"):
102+
dialect.create_connect_args(url)
103+
104+
@pytest.mark.parametrize(
105+
"raw,expected",
106+
[
107+
("1", True),
108+
("true", True),
109+
("True", True),
110+
("YES", True),
111+
("on", True),
112+
("0", False),
113+
("false", False),
114+
("no", False),
115+
("off", False),
116+
],
117+
)
118+
def test_trust_server_heartbeat_parses_boolean(self, raw: str, expected: bool) -> None:
119+
"""ISSUE-101 URL plumbing — post-review follow-up.
120+
121+
URL values arrive as strings; bool("False") would evaluate
122+
truthy if used directly, so we use a dedicated parser.
123+
"""
124+
dialect = DqliteDialect()
125+
url = make_url(f"dqlite://host:19001/db?trust_server_heartbeat={raw}")
126+
_, kwargs = dialect.create_connect_args(url)
127+
assert kwargs["trust_server_heartbeat"] is expected
128+
91129

92130
class TestDoPingNarrowExceptions:
93131
def test_returns_true_on_success(self) -> None:

0 commit comments

Comments
 (0)