Skip to content

Commit d6b8da5

Browse files
Plumb max_continuation_frames + trust_server_heartbeat through pool and cluster
Post-review follow-up. The security review flagged that these knobs existed only on DqliteProtocol/DqliteConnection — operators using ConnectionPool or ClusterClient.connect (which is the common path) had no way to tune them. Add both parameters to: - ConnectionPool.__init__ - ClusterClient.connect and forward them through _create_connection / try_connect to the underlying DqliteConnection. Also strengthens the trust_server_heartbeat tests: - opt-in with handshake actually amplifies timeout to the server value - default ignores handshake heartbeat entirely (no amplification) - opt-in respects the 300 s hard cap Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 8a6c9f6 commit d6b8da5

3 files changed

Lines changed: 87 additions & 6 deletions

File tree

src/dqliteclient/cluster.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,17 @@ async def connect(
160160
database: str = "default",
161161
*,
162162
max_total_rows: int | None = 10_000_000,
163+
max_continuation_frames: int | None = 100_000,
164+
trust_server_heartbeat: bool = False,
163165
max_attempts: int | None = None,
164166
) -> DqliteConnection:
165167
"""Connect to the cluster leader.
166168
167-
Returns a connection to the current leader. ``max_total_rows``
168-
is forwarded to the underlying :class:`DqliteConnection` so
169-
callers (including :class:`ConnectionPool`) can tune the
170-
cumulative row cap from one place.
169+
Returns a connection to the current leader. ``max_total_rows``,
170+
``max_continuation_frames``, and ``trust_server_heartbeat`` are
171+
forwarded to the underlying :class:`DqliteConnection` so callers
172+
(including :class:`ConnectionPool`) can tune security/DoS
173+
governors from one place.
171174
172175
``max_attempts`` overrides the default
173176
:data:`_DEFAULT_CONNECT_MAX_ATTEMPTS` (ISSUE-109).
@@ -194,6 +197,8 @@ async def try_connect() -> DqliteConnection:
194197
database=database,
195198
timeout=self._timeout,
196199
max_total_rows=max_total_rows,
200+
max_continuation_frames=max_continuation_frames,
201+
trust_server_heartbeat=trust_server_heartbeat,
197202
)
198203
await conn.connect()
199204
return conn

src/dqliteclient/pool.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from dqliteclient.connection import DqliteConnection
1212
from dqliteclient.exceptions import DqliteConnectionError
1313
from dqliteclient.node_store import NodeStore
14-
from dqliteclient.protocol import _validate_max_total_rows
14+
from dqliteclient.protocol import _validate_max_total_rows, _validate_positive_int_or_none
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -62,6 +62,8 @@ def __init__(
6262
cluster: ClusterClient | None = None,
6363
node_store: NodeStore | None = None,
6464
max_total_rows: int | None = 10_000_000,
65+
max_continuation_frames: int | None = 100_000,
66+
trust_server_heartbeat: bool = False,
6567
) -> None:
6668
"""Initialize connection pool.
6769
@@ -99,6 +101,16 @@ def __init__(
99101
connection inherits the same governor. ``None`` disables
100102
the cap entirely (not recommended in production —
101103
bounds memory against slow-drip attacks).
104+
max_continuation_frames: Per-query continuation-frame cap
105+
(ISSUE-98). Complements ``max_total_rows``: a server
106+
sending 1-row-per-frame can inflict O(n) Python decode
107+
work where n is the row cap; the frame cap bounds that.
108+
Forwarded to every :class:`DqliteConnection`.
109+
trust_server_heartbeat: When True, the per-read deadline on
110+
every connection widens to the server-advertised
111+
heartbeat (up to a 300 s hard cap). Default False —
112+
operator-configured ``timeout`` is authoritative and
113+
the server cannot amplify it (ISSUE-101).
102114
"""
103115
if min_size < 0:
104116
raise ValueError(f"min_size must be non-negative, got {min_size}")
@@ -119,6 +131,10 @@ def __init__(
119131
self._max_size = max_size
120132
self._timeout = timeout
121133
self._max_total_rows = _validate_max_total_rows(max_total_rows)
134+
self._max_continuation_frames = _validate_positive_int_or_none(
135+
max_continuation_frames, "max_continuation_frames"
136+
)
137+
self._trust_server_heartbeat = trust_server_heartbeat
122138

123139
if cluster is not None:
124140
self._cluster = cluster
@@ -172,7 +188,10 @@ async def _create_connection(self) -> DqliteConnection:
172188
reservation.
173189
"""
174190
return await self._cluster.connect(
175-
database=self._database, max_total_rows=self._max_total_rows
191+
database=self._database,
192+
max_total_rows=self._max_total_rows,
193+
max_continuation_frames=self._max_continuation_frames,
194+
trust_server_heartbeat=self._trust_server_heartbeat,
176195
)
177196

178197
async def _release_reservation(self) -> None:

tests/test_max_total_rows_cap.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,60 @@ def test_opt_in_respected(self) -> None:
158158
writer = MagicMock()
159159
p = DqliteProtocol(reader, writer, timeout=5.0, trust_server_heartbeat=True)
160160
assert p._trust_server_heartbeat is True
161+
162+
def test_opt_in_amplifies_timeout_via_handshake(self) -> None:
163+
"""With trust_server_heartbeat=True, a handshake reply with a
164+
larger-than-local heartbeat widens the per-read deadline (up to
165+
the 300 s hard cap). Exercises the actual handshake code path
166+
rather than just checking the flag is set (ISSUE-101 review)."""
167+
from dqlitewire.messages import WelcomeResponse
168+
169+
reader = AsyncMock()
170+
writer = MagicMock()
171+
writer.drain = AsyncMock()
172+
writer.close = MagicMock()
173+
# heartbeat=60_000 ms = 60 s, well above timeout=5 s, under 300 s cap
174+
reader.read.side_effect = [
175+
WelcomeResponse(heartbeat_timeout=60_000).encode(),
176+
]
177+
p = DqliteProtocol(reader, writer, timeout=5.0, trust_server_heartbeat=True)
178+
asyncio.run(p.handshake())
179+
assert p._timeout == 60.0, (
180+
f"trust_server_heartbeat=True should widen timeout to 60s, got {p._timeout}"
181+
)
182+
183+
def test_default_ignores_handshake_heartbeat(self) -> None:
184+
"""With trust_server_heartbeat=False (default), a handshake
185+
reply with a larger heartbeat is recorded but does NOT widen
186+
the per-read deadline."""
187+
from dqlitewire.messages import WelcomeResponse
188+
189+
reader = AsyncMock()
190+
writer = MagicMock()
191+
writer.drain = AsyncMock()
192+
writer.close = MagicMock()
193+
reader.read.side_effect = [
194+
WelcomeResponse(heartbeat_timeout=300_000).encode(),
195+
]
196+
p = DqliteProtocol(reader, writer, timeout=5.0) # default = False
197+
asyncio.run(p.handshake())
198+
# Server value recorded for diagnostics, but timeout unchanged.
199+
assert p._heartbeat_timeout == 300_000
200+
assert p._timeout == 5.0, f"default should not widen timeout, got {p._timeout}"
201+
202+
def test_opt_in_respects_hard_300s_cap(self) -> None:
203+
"""Even with trust_server_heartbeat=True, a server sending an
204+
absurdly large heartbeat is clamped at 300 s."""
205+
from dqlitewire.messages import WelcomeResponse
206+
207+
reader = AsyncMock()
208+
writer = MagicMock()
209+
writer.drain = AsyncMock()
210+
writer.close = MagicMock()
211+
# heartbeat=3_600_000 ms = 1 h; clamp to 300 s.
212+
reader.read.side_effect = [
213+
WelcomeResponse(heartbeat_timeout=3_600_000).encode(),
214+
]
215+
p = DqliteProtocol(reader, writer, timeout=5.0, trust_server_heartbeat=True)
216+
asyncio.run(p.handshake())
217+
assert p._timeout == 300.0

0 commit comments

Comments (0)