Skip to content

Commit 50e472c

Browse files
Per-frame cap and opt-in heartbeat trust
max_continuation_frames: the per-query row cap is already enforced, but a hostile server sending 1-row frames can still pin a client CPU with ~max_total_rows iterations of Python-level decode work. Add a complementary ``max_continuation_frames`` knob (default 100_000) enforced inside ``_drain_continuations``. An earlier design pass originally called for both knobs; only ``max_total_rows`` was delivered then. Plumbed through :class:`DqliteProtocol` and :class:`DqliteConnection`; callers get the guard with no code change. trust_server_heartbeat: previously the handshake unconditionally widened the client's per-read deadline to the server-advertised heartbeat (capped at 300 s) — a 30× amplification from the common 10 s default. A hostile server could therefore override the operator's configured timeout. Invert the default: ``trust_server_heartbeat=False`` means the server value is recorded for diagnostics only and ``timeout`` is authoritative. Opt-in preserves the previous behavior for callers who need adaptive behavior. Validation: extract a shared ``_validate_positive_int_or_none`` helper used for both ``max_total_rows`` and ``max_continuation_frames``. Tests: extend test_max_total_rows_cap.py with TestMaxContinuationFramesEnforcement (exceeds-cap, within-cap, None-disables) and TestTrustServerHeartbeat (default no amplification, opt-in honored). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e115841 commit 50e472c

3 files changed

Lines changed: 169 additions & 3 deletions

File tree

src/dqliteclient/connection.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
OperationalError,
1515
ProtocolError,
1616
)
17-
from dqliteclient.protocol import DqliteProtocol, _validate_max_total_rows
17+
from dqliteclient.protocol import (
18+
DqliteProtocol,
19+
_validate_max_total_rows,
20+
_validate_positive_int_or_none,
21+
)
1822
from dqlitewire.exceptions import EncodeError as _WireEncodeError
1923

2024
# dqlite error codes that indicate a leader change (SQLite extended error codes)
@@ -76,6 +80,8 @@ def __init__(
7680
database: str = "default",
7781
timeout: float = 10.0,
7882
max_total_rows: int | None = 10_000_000,
83+
max_continuation_frames: int | None = 100_000,
84+
trust_server_heartbeat: bool = False,
7985
) -> None:
8086
"""Initialize connection (does not connect yet).
8187
@@ -87,13 +93,27 @@ def __init__(
8793
frames for a single query. Prevents a slow-drip server
8894
from keeping the client alive indefinitely within the
8995
per-operation deadline. Set to ``None`` to disable.
96+
max_continuation_frames: Maximum number of continuation
97+
frames in a single query result. Caps the per-query
98+
Python-side decode work a hostile server can inflict
99+
by sending many 1-row frames (ISSUE-98). Set to
100+
``None`` to disable.
101+
trust_server_heartbeat: When True, widen the per-read
102+
deadline to the server-advertised heartbeat (subject
103+
to a 300 s hard cap). When False (default), ``timeout``
104+
is authoritative — the server value cannot amplify it
105+
(ISSUE-101).
90106
"""
91107
if not math.isfinite(timeout) or timeout <= 0:
92108
raise ValueError(f"timeout must be a positive finite number, got {timeout}")
93109
self._address = address
94110
self._database = database
95111
self._timeout = timeout
96112
self._max_total_rows = _validate_max_total_rows(max_total_rows)
113+
self._max_continuation_frames = _validate_positive_int_or_none(
114+
max_continuation_frames, "max_continuation_frames"
115+
)
116+
self._trust_server_heartbeat = trust_server_heartbeat
97117
self._protocol: DqliteProtocol | None = None
98118
self._db_id: int | None = None
99119
self._in_transaction = False
@@ -148,6 +168,8 @@ async def connect(self) -> None:
148168
writer,
149169
timeout=self._timeout,
150170
max_total_rows=self._max_total_rows,
171+
max_continuation_frames=self._max_continuation_frames,
172+
trust_server_heartbeat=self._trust_server_heartbeat,
151173
)
152174

153175
try:

src/dqliteclient/protocol.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,23 @@
3232
_READ_CHUNK_SIZE = 4096
3333

3434

35+
def _validate_positive_int_or_none(
36+
value: int | None, name: str
37+
) -> int | None:
38+
"""Shared validation for positive-int-or-None parameters.
39+
40+
Used for both ``max_total_rows`` and ``max_continuation_frames``
41+
(ISSUE-98). None disables the cap; any int value must be > 0.
42+
"""
43+
if value is None:
44+
return None
45+
if not isinstance(value, int) or isinstance(value, bool):
46+
raise TypeError(f"{name} must be int or None, got {type(value).__name__}")
47+
if value <= 0:
48+
raise ValueError(f"{name} must be > 0 or None, got {value}")
49+
return value
50+
51+
3552
def _validate_max_total_rows(value: int | None) -> int | None:
3653
"""Validate the ``max_total_rows`` constructor argument.
3754
@@ -56,6 +73,8 @@ def __init__(
5673
writer: asyncio.StreamWriter,
5774
timeout: float = 15.0,
5875
max_total_rows: int | None = 10_000_000,
76+
max_continuation_frames: int | None = 100_000,
77+
trust_server_heartbeat: bool = False,
5978
) -> None:
6079
self._reader = reader
6180
self._writer = writer
@@ -69,6 +88,21 @@ def __init__(
6988
# could legitimately allocate hundreds of millions of rows over
7089
# the full deadline. None disables the cap.
7190
self._max_total_rows = _validate_max_total_rows(max_total_rows)
91+
# Per-query frame cap. Complements max_total_rows: a server
92+
# sending 10M 1-row frames to reach the row cap would still
93+
# burn 10M × decode-cost of Python work; the frame cap bounds
94+
# that at ~100k iterations (ISSUE-98).
95+
self._max_continuation_frames = _validate_positive_int_or_none(
96+
max_continuation_frames, "max_continuation_frames"
97+
)
98+
# When True, the client honors the server-advertised heartbeat
99+
# timeout to adjust its per-read deadline (subject to the 300 s
100+
# hard cap). When False (default), the server value is recorded
101+
# for diagnostics only and the operator-configured ``timeout``
102+
# is authoritative. Opt-in protects operators whose timeout is
103+
# a latency-SLO boundary from server-induced amplification
104+
# (ISSUE-101).
105+
self._trust_server_heartbeat = trust_server_heartbeat
72106

73107
async def handshake(self, client_id: int | None = None) -> int:
74108
"""Perform protocol handshake.
@@ -96,8 +130,14 @@ async def handshake(self, client_id: int | None = None) -> int:
96130

97131
self._client_id = client_id
98132
self._heartbeat_timeout = response.heartbeat_timeout
99-
# Use heartbeat timeout for subsequent reads if larger than default
100-
if response.heartbeat_timeout > 0:
133+
# Use the server-advertised heartbeat only when explicitly
134+
# trusted. Previously we always widened ``self._timeout`` up
135+
# to 300 s based on the server value, which let a hostile
136+
# server amplify the operator's configured timeout up to 30×
137+
# (ISSUE-101). Default is now opt-out: the server value is
138+
# recorded for diagnostics but does not change the per-read
139+
# deadline.
140+
if self._trust_server_heartbeat and response.heartbeat_timeout > 0:
101141
heartbeat_seconds = response.heartbeat_timeout / 1000.0
102142
# Cap to prevent a malicious/buggy server from disabling timeouts
103143
self._timeout = max(self._timeout, min(heartbeat_seconds, 300.0))
@@ -230,12 +270,27 @@ async def _drain_continuations(
230270
"""
231271
all_rows = list(initial.rows)
232272
response = initial
273+
frames = 1 # the initial frame counts
233274
while response.has_more:
234275
next_response = await self._read_continuation(deadline=deadline)
276+
frames += 1
235277
if not next_response.rows and next_response.has_more:
236278
raise ProtocolError(
237279
"ROWS continuation made no progress: frame had 0 rows and has_more=True"
238280
)
281+
if (
282+
self._max_continuation_frames is not None
283+
and frames > self._max_continuation_frames
284+
):
285+
# Per-frame cap complements max_total_rows (ISSUE-98): a
286+
# slow-drip server sending 1-row-per-frame would
287+
# otherwise pin a client CPU with O(n) iterations of
288+
# decode work, where n is max_total_rows.
289+
raise ProtocolError(
290+
f"Query exceeded max_continuation_frames cap "
291+
f"({self._max_continuation_frames}); server may be "
292+
f"slow-dripping rows."
293+
)
239294
if self._max_total_rows is not None and (
240295
len(all_rows) + len(next_response.rows) > self._max_total_rows
241296
):

tests/test_max_total_rows_cap.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,92 @@ def test_none_disables_cap(self) -> None:
6969

7070
rows = asyncio.run(p._drain_continuations(initial, deadline=999999.0))
7171
assert len(rows) == 10_000
72+
73+
74+
class TestMaxContinuationFramesEnforcement:
75+
"""ISSUE-98: per-frame cap complements max_total_rows.
76+
77+
A slow-drip server sending many 1-row frames could pin a client CPU
78+
with ~max_total_rows iterations of Python-level decode work. The
79+
frame cap bounds that at a configurable number of iterations.
80+
"""
81+
82+
def test_exceeding_frame_cap_raises(self) -> None:
83+
reader = MagicMock()
84+
writer = MagicMock()
85+
# Cap at 3 frames (initial + 2 continuations).
86+
p = DqliteProtocol(
87+
reader,
88+
writer,
89+
timeout=5.0,
90+
max_total_rows=None,
91+
max_continuation_frames=3,
92+
)
93+
94+
initial = _make_rows_response([[1]], has_more=True)
95+
# Every continuation delivers one row with has_more=True. The
96+
# test mock returns the same response object each time so the
97+
# frame counter is what drives termination.
98+
one_row_continuation = _make_rows_response([[2]], has_more=True)
99+
100+
p._read_continuation = AsyncMock(return_value=one_row_continuation) # type: ignore[method-assign]
101+
102+
async def run() -> None:
103+
await p._drain_continuations(initial, deadline=999999.0)
104+
105+
with pytest.raises(ProtocolError, match="max_continuation_frames"):
106+
asyncio.run(run())
107+
108+
def test_within_frame_cap_succeeds(self) -> None:
109+
reader = MagicMock()
110+
writer = MagicMock()
111+
p = DqliteProtocol(
112+
reader,
113+
writer,
114+
timeout=5.0,
115+
max_continuation_frames=5,
116+
)
117+
initial = _make_rows_response([[1]], has_more=True)
118+
last = _make_rows_response([[2], [3]], has_more=False)
119+
120+
p._read_continuation = AsyncMock(return_value=last) # type: ignore[method-assign]
121+
rows = asyncio.run(p._drain_continuations(initial, deadline=999999.0))
122+
assert len(rows) == 3
123+
124+
def test_none_disables_frame_cap(self) -> None:
125+
reader = MagicMock()
126+
writer = MagicMock()
127+
p = DqliteProtocol(
128+
reader,
129+
writer,
130+
timeout=5.0,
131+
max_continuation_frames=None,
132+
)
133+
initial = _make_rows_response([[1]], has_more=True)
134+
last = _make_rows_response([[2]], has_more=False)
135+
136+
p._read_continuation = AsyncMock(return_value=last) # type: ignore[method-assign]
137+
rows = asyncio.run(p._drain_continuations(initial, deadline=999999.0))
138+
assert len(rows) == 2
139+
140+
141+
class TestTrustServerHeartbeat:
142+
"""ISSUE-101: server heartbeat no longer widens client timeout by default."""
143+
144+
def test_default_does_not_amplify_timeout(self) -> None:
145+
reader = MagicMock()
146+
writer = MagicMock()
147+
p = DqliteProtocol(reader, writer, timeout=5.0)
148+
# Inject the server-advertised value; the field is set in
149+
# handshake() but we bypass the socket dance for a pure
150+
# attribute test.
151+
p._heartbeat_timeout = 300_000 # 300 s in ms, far above 5 s
152+
# trust_server_heartbeat defaults to False, so timeout stayed 5.
153+
assert p._timeout == 5.0
154+
assert p._trust_server_heartbeat is False
155+
156+
def test_opt_in_respected(self) -> None:
157+
reader = MagicMock()
158+
writer = MagicMock()
159+
p = DqliteProtocol(reader, writer, timeout=5.0, trust_server_heartbeat=True)
160+
assert p._trust_server_heartbeat is True

0 commit comments

Comments
 (0)