Skip to content

Commit 7f02664

Browse files
fix: add read timeouts to protocol to prevent indefinite hangs
Protocol reads now go through asyncio.wait_for with a configurable timeout (default 15s). After the handshake, the timeout is raised to at least the server's heartbeat timeout. Previously, a hung server that kept its TCP connection open would block the client forever.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 53de6a3 commit 7f02664

4 files changed

Lines changed: 46 additions & 8 deletions

File tree

src/dqliteclient/cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ async def _query_leader(self, address: str) -> str | None:
7171
except (TimeoutError, OSError):
7272
return None
7373

74-
protocol = DqliteProtocol(reader, writer)
74+
protocol = DqliteProtocol(reader, writer, timeout=self._timeout)
7575

7676
try:
7777
await protocol.handshake()

src/dqliteclient/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ async def connect(self) -> None:
6161
except OSError as e:
6262
raise DqliteConnectionError(f"Failed to connect to {self._address}: {e}") from e
6363

64-
self._protocol = DqliteProtocol(reader, writer)
64+
self._protocol = DqliteProtocol(reader, writer, timeout=self._timeout)
6565

6666
try:
6767
await self._protocol.handshake()

src/dqliteclient/protocol.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
self,
3333
reader: asyncio.StreamReader,
3434
writer: asyncio.StreamWriter,
35+
timeout: float = 15.0,
3536
) -> None:
3637
self._reader = reader
3738
self._writer = writer
@@ -40,6 +41,7 @@ def __init__(
4041
self._buffer = ReadBuffer()
4142
self._client_id = 0
4243
self._heartbeat_timeout = 0
44+
self._timeout = timeout
4345

4446
async def handshake(self, client_id: int = 0) -> int:
4547
"""Perform protocol handshake.
@@ -66,6 +68,10 @@ async def handshake(self, client_id: int = 0) -> int:
6668

6769
self._client_id = client_id
6870
self._heartbeat_timeout = response.heartbeat_timeout
71+
# Use the heartbeat timeout (converted from ms to seconds) for subsequent reads if it exceeds the configured timeout
72+
if response.heartbeat_timeout > 0:
73+
heartbeat_seconds = response.heartbeat_timeout / 1000.0
74+
self._timeout = max(self._timeout, heartbeat_seconds)
6975
return response.heartbeat_timeout
7076

7177
async def get_leader(self) -> tuple[int, str]:
@@ -191,23 +197,33 @@ async def query_sql(
191197

192198
return column_names, all_rows
193199

200+
async def _read_data(self) -> bytes:
201+
"""Read data from the stream with timeout."""
202+
try:
203+
data = await asyncio.wait_for(
204+
self._reader.read(4096), timeout=self._timeout
205+
)
206+
except TimeoutError:
207+
raise DqliteConnectionError(
208+
f"Server read timed out after {self._timeout}s"
209+
) from None
210+
if not data:
211+
raise DqliteConnectionError("Connection closed by server")
212+
return data
213+
194214
async def _read_continuation(self) -> RowsResponse:
195215
"""Read and decode a ROWS continuation frame."""
196216
while True:
197217
result = self._decoder.decode_continuation()
198218
if result is not None:
199219
return result
200-
data = await self._reader.read(4096)
201-
if not data:
202-
raise DqliteConnectionError("Connection closed by server")
220+
data = await self._read_data()
203221
self._decoder.feed(data)
204222

205223
async def _read_response(self) -> Message:
206224
"""Read and decode the next response message."""
207225
while not self._decoder.has_message():
208-
data = await self._reader.read(4096)
209-
if not data:
210-
raise DqliteConnectionError("Connection closed by server")
226+
data = await self._read_data()
211227
self._decoder.feed(data)
212228

213229
message = self._decoder.decode()

tests/test_protocol.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,28 @@ async def test_query_sql_multipart(
138138
assert rows[0] == [1, "alice"]
139139
assert rows[1] == [2, "bob"]
140140

141+
async def test_read_timeout(
142+
self,
143+
mock_reader: AsyncMock,
144+
mock_writer: MagicMock,
145+
) -> None:
146+
"""Protocol reads should time out instead of blocking forever."""
147+
import asyncio
148+
149+
protocol = DqliteProtocol(mock_reader, mock_writer, timeout=0.1)
150+
151+
# Simulate a server that hangs (never returns data)
152+
async def hang_forever(*args, **kwargs):
153+
await asyncio.sleep(100)
154+
return b""
155+
156+
mock_reader.read.side_effect = hang_forever
157+
158+
from dqliteclient.exceptions import DqliteConnectionError
159+
160+
with pytest.raises(DqliteConnectionError, match="timed out"):
161+
await protocol.exec_sql(1, "SELECT 1")
162+
141163
async def test_close(
142164
self,
143165
protocol: DqliteProtocol,

0 commit comments

Comments
 (0)