Skip to content

Commit 40909eb

Browse files
fix: enforce per-operation deadline on protocol reads
self._timeout applied to each individual read. A slow server trickling bytes just under the timeout on every chunk could keep a call alive indefinitely. query_sql with N continuation frames was also unbounded: N × self._timeout. Establish a per-operation deadline at the start of _read_response and _read_continuation; _read_data caps each chunk's timeout by the remaining budget. query_sql now threads one deadline through the initial response and every continuation, so the wall-time budget spans the whole call. Also reject non-finite (inf/nan) timeouts at DqliteConnection construction. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 7639ddb commit 40909eb

4 files changed

Lines changed: 89 additions & 17 deletions

File tree

src/dqliteclient/connection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,10 @@ def __init__(
8080
database: Database name to open
8181
timeout: Connection timeout in seconds
8282
"""
83-
if timeout <= 0:
84-
raise ValueError(f"timeout must be positive, got {timeout}")
83+
import math
84+
85+
if not math.isfinite(timeout) or timeout <= 0:
86+
raise ValueError(f"timeout must be a positive finite number, got {timeout}")
8587
self._address = address
8688
self._database = database
8789
self._timeout = timeout

src/dqliteclient/protocol.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,11 @@ async def query_sql(
184184
self._writer.write(request.encode())
185185
await self._send()
186186

187-
response = await self._read_response()
187+
# Single deadline spans the initial response plus every continuation
188+
# frame; otherwise a server that split a reply into N frames could
189+
# legitimately take N * self._timeout to complete.
190+
deadline = self._operation_deadline()
191+
response = await self._read_response(deadline=deadline)
188192

189193
if isinstance(response, FailureResponse):
190194
raise OperationalError(response.code, response.message)
@@ -200,7 +204,7 @@ async def query_sql(
200204
# marker), matching the C dqlite server's wire format.
201205
all_rows = list(response.rows)
202206
while response.has_more:
203-
next_response = await self._read_continuation()
207+
next_response = await self._read_continuation(deadline=deadline)
204208
if not next_response.rows and next_response.has_more:
205209
# Server claimed "more coming" but delivered zero rows in a
206210
# continuation frame. That would spin forever (known
@@ -221,36 +225,67 @@ async def _send(self) -> None:
221225
except (ConnectionError, OSError, RuntimeError) as e:
222226
raise DqliteConnectionError(f"Write failed: {e}") from e
223227

224-
async def _read_data(self) -> bytes:
225-
"""Read data from the stream with timeout.
228+
async def _read_data(self, deadline: float | None = None) -> bytes:
229+
"""Read a chunk from the stream, bounded by a per-operation deadline.
230+
231+
If ``deadline`` is set (monotonic time), the per-chunk timeout is
232+
capped by the remaining budget — a slow-drip server that returned
233+
just under the per-read timeout on every chunk used to be able to
234+
keep a call alive indefinitely.
226235
227236
Transport errors (ConnectionResetError, BrokenPipeError, OSError,
228237
RuntimeError("Transport is closed")) are wrapped in
229238
DqliteConnectionError to match the write-path behaviour.
230239
"""
240+
if deadline is not None:
241+
remaining = deadline - asyncio.get_running_loop().time()
242+
if remaining <= 0:
243+
raise DqliteConnectionError(f"Operation exceeded {self._timeout}s deadline")
244+
timeout = min(remaining, self._timeout)
245+
else:
246+
timeout = self._timeout
231247
try:
232-
data = await asyncio.wait_for(self._reader.read(4096), timeout=self._timeout)
248+
data = await asyncio.wait_for(self._reader.read(4096), timeout=timeout)
233249
except TimeoutError:
234-
raise DqliteConnectionError(f"Server read timed out after {self._timeout}s") from None
250+
raise DqliteConnectionError(f"Server read timed out after {timeout:.1f}s") from None
235251
except (ConnectionError, OSError, RuntimeError) as e:
236252
raise DqliteConnectionError(f"Read failed: {e}") from e
237253
if not data:
238254
raise DqliteConnectionError("Connection closed by server")
239255
return data
240256

241-
async def _read_continuation(self) -> RowsResponse:
242-
"""Read and decode a ROWS continuation frame."""
257+
def _operation_deadline(self) -> float:
258+
"""Deadline (monotonic seconds) for a single protocol operation."""
259+
return asyncio.get_running_loop().time() + self._timeout
260+
261+
async def _read_continuation(self, deadline: float | None = None) -> RowsResponse:
262+
"""Read and decode a ROWS continuation frame.
263+
264+
If ``deadline`` is None, a fresh per-operation deadline is set;
265+
query_sql passes its own deadline so the budget spans every
266+
continuation frame, not each one individually.
267+
"""
268+
if deadline is None:
269+
deadline = self._operation_deadline()
243270
while True:
244271
result = self._decoder.decode_continuation()
245272
if result is not None:
246273
return result
247-
data = await self._read_data()
274+
data = await self._read_data(deadline=deadline)
248275
self._decoder.feed(data)
249276

250-
async def _read_response(self) -> Message:
251-
"""Read and decode the next response message."""
277+
async def _read_response(self, deadline: float | None = None) -> Message:
278+
"""Read and decode the next response message.
279+
280+
If ``deadline`` is None, a fresh per-operation deadline is set for
281+
this one response; callers that span multiple reads (e.g. query_sql
282+
across continuation frames) pass an externally-held deadline so
283+
the cumulative wall time is bounded.
284+
"""
285+
if deadline is None:
286+
deadline = self._operation_deadline()
252287
while not self._decoder.has_message():
253-
data = await self._read_data()
288+
data = await self._read_data(deadline=deadline)
254289
self._decoder.feed(data)
255290

256291
message = self._decoder.decode()

tests/test_connection.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,21 @@ def test_min_valid_port(self) -> None:
9090

9191
class TestDqliteConnection:
9292
def test_zero_timeout_raises(self) -> None:
93-
with pytest.raises(ValueError, match="timeout must be positive"):
93+
with pytest.raises(ValueError, match="timeout must be"):
9494
DqliteConnection("localhost:9001", timeout=0)
9595

9696
def test_negative_timeout_raises(self) -> None:
97-
with pytest.raises(ValueError, match="timeout must be positive"):
97+
with pytest.raises(ValueError, match="timeout must be"):
9898
DqliteConnection("localhost:9001", timeout=-1)
9999

100+
def test_infinite_timeout_raises(self) -> None:
101+
with pytest.raises(ValueError, match="finite"):
102+
DqliteConnection("localhost:9001", timeout=float("inf"))
103+
104+
def test_nan_timeout_raises(self) -> None:
105+
with pytest.raises(ValueError, match="finite"):
106+
DqliteConnection("localhost:9001", timeout=float("nan"))
107+
100108
def test_init(self) -> None:
101109
conn = DqliteConnection("localhost:9001", database="test", timeout=5.0)
102110
assert conn.address == "localhost:9001"

tests/test_protocol.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from dqliteclient.exceptions import OperationalError, ProtocolError
7+
from dqliteclient.exceptions import DqliteConnectionError, OperationalError, ProtocolError
88
from dqliteclient.protocol import DqliteProtocol
99
from dqlitewire.messages import (
1010
FailureResponse,
@@ -69,6 +69,33 @@ async def test_reader_errors_are_wrapped(
6969
await protocol._read_data()
7070
assert exc_info.value.__cause__ is err
7171

72+
async def test_read_response_enforces_operation_deadline(
73+
self,
74+
protocol: DqliteProtocol,
75+
mock_reader: AsyncMock,
76+
) -> None:
77+
"""Even if each individual read returns just under the per-read
78+
timeout, the cumulative operation deadline must fire.
79+
"""
80+
import asyncio
81+
import time
82+
83+
protocol._timeout = 0.2
84+
85+
async def drip_forever(_n: int) -> bytes:
86+
await asyncio.sleep(0.1)
87+
return b"\x00" # 1 byte, never completes a message
88+
89+
mock_reader.read.side_effect = drip_forever
90+
91+
t0 = time.monotonic()
92+
with pytest.raises(DqliteConnectionError, match="deadline|timed out"):
93+
await protocol._read_response()
94+
elapsed = time.monotonic() - t0
95+
assert elapsed < 1.0, (
96+
f"_read_response must bail at the operation deadline; took {elapsed:.3f}s"
97+
)
98+
7299
async def test_handshake_success(
73100
self,
74101
protocol: DqliteProtocol,

0 commit comments

Comments
 (0)