Skip to content

Commit 7639ddb

Browse files
fix: wrap writer.drain and _read_data transport errors as DqliteConnectionError
Each protocol method awaited writer.drain() directly; ConnectionReset, BrokenPipeError, OSError, and RuntimeError("Transport is closed") leaked to callers raw. Application code catching DqliteConnectionError to retry missed write-path transport errors entirely. _read_data had the same gap on the read path: only TimeoutError was wrapped. Add a _send() helper that wraps the drain and route all seven call sites through it, and extend _read_data's except to the same transport-error set. Parametrized tests cover the four known leaking types for both paths and assert __cause__ is preserved. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a7912ea commit 7639ddb

2 files changed

Lines changed: 75 additions & 8 deletions

File tree

src/dqliteclient/protocol.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def handshake(self, client_id: int | None = None) -> int:
5656
# Send protocol version + client registration together
5757
request = ClientRequest(client_id=client_id)
5858
self._writer.write(MessageEncoder().encode_handshake() + request.encode())
59-
await self._writer.drain()
59+
await self._send()
6060

6161
# Read welcome response
6262
response = await self._read_response()
@@ -83,7 +83,7 @@ async def get_leader(self) -> tuple[int, str]:
8383
"""
8484
request = LeaderRequest()
8585
self._writer.write(request.encode())
86-
await self._writer.drain()
86+
await self._send()
8787

8888
response = await self._read_response()
8989

@@ -102,7 +102,7 @@ async def open_database(self, name: str, flags: int = 0, vfs: str = "") -> int:
102102
"""
103103
request = OpenRequest(name=name, flags=flags, vfs=vfs)
104104
self._writer.write(request.encode())
105-
await self._writer.drain()
105+
await self._send()
106106

107107
response = await self._read_response()
108108

@@ -121,7 +121,7 @@ async def prepare(self, db_id: int, sql: str) -> tuple[int, int]:
121121
"""
122122
request = PrepareRequest(db_id=db_id, sql=sql)
123123
self._writer.write(request.encode())
124-
await self._writer.drain()
124+
await self._send()
125125

126126
response = await self._read_response()
127127

@@ -137,7 +137,7 @@ async def finalize(self, db_id: int, stmt_id: int) -> None:
137137
"""Finalize (close) a prepared statement."""
138138
request = FinalizeRequest(db_id=db_id, stmt_id=stmt_id)
139139
self._writer.write(request.encode())
140-
await self._writer.drain()
140+
await self._send()
141141

142142
response = await self._read_response()
143143

@@ -159,7 +159,7 @@ async def exec_sql(
159159
"""
160160
request = ExecSqlRequest(db_id=db_id, sql=sql, params=params if params is not None else [])
161161
self._writer.write(request.encode())
162-
await self._writer.drain()
162+
await self._send()
163163

164164
response = await self._read_response()
165165

@@ -182,7 +182,7 @@ async def query_sql(
182182
"""
183183
request = QuerySqlRequest(db_id=db_id, sql=sql, params=params if params is not None else [])
184184
self._writer.write(request.encode())
185-
await self._writer.drain()
185+
await self._send()
186186

187187
response = await self._read_response()
188188

@@ -214,12 +214,26 @@ async def query_sql(
214214

215215
return column_names, all_rows
216216

217+
async def _send(self) -> None:
218+
"""Drain the writer, wrapping transport errors as DqliteConnectionError."""
219+
try:
220+
await self._writer.drain()
221+
except (ConnectionError, OSError, RuntimeError) as e:
222+
raise DqliteConnectionError(f"Write failed: {e}") from e
223+
217224
async def _read_data(self) -> bytes:
218-
"""Read data from the stream with timeout."""
225+
"""Read data from the stream with timeout.
226+
227+
Transport errors (ConnectionResetError, BrokenPipeError, OSError,
228+
RuntimeError("Transport is closed")) are wrapped in
229+
DqliteConnectionError to match the write-path behaviour.
230+
"""
219231
try:
220232
data = await asyncio.wait_for(self._reader.read(4096), timeout=self._timeout)
221233
except TimeoutError:
222234
raise DqliteConnectionError(f"Server read timed out after {self._timeout}s") from None
235+
except (ConnectionError, OSError, RuntimeError) as e:
236+
raise DqliteConnectionError(f"Read failed: {e}") from e
223237
if not data:
224238
raise DqliteConnectionError("Connection closed by server")
225239
return data

tests/test_protocol.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,59 @@ class TestDqliteProtocol:
1616
def protocol(self, mock_reader: AsyncMock, mock_writer: MagicMock) -> DqliteProtocol:
1717
return DqliteProtocol(mock_reader, mock_writer)
1818

19+
@pytest.mark.parametrize(
20+
"err",
21+
[
22+
ConnectionResetError("peer reset"),
23+
BrokenPipeError("pipe gone"),
24+
OSError("generic os error"),
25+
RuntimeError("Transport is closed"),
26+
],
27+
)
28+
async def test_writer_drain_errors_are_wrapped(
29+
self,
30+
protocol: DqliteProtocol,
31+
mock_reader: AsyncMock,
32+
mock_writer: MagicMock,
33+
err: BaseException,
34+
) -> None:
35+
"""Transport errors raised by writer.drain() must surface as
36+
DqliteConnectionError so callers catching the client's error
37+
hierarchy don't miss them, with the original attached via __cause__.
38+
"""
39+
from dqliteclient.exceptions import DqliteConnectionError
40+
41+
mock_writer.drain = AsyncMock(side_effect=err)
42+
43+
with pytest.raises(DqliteConnectionError) as exc_info:
44+
await protocol.handshake()
45+
assert exc_info.value.__cause__ is err
46+
47+
@pytest.mark.parametrize(
48+
"err",
49+
[
50+
ConnectionResetError("peer reset mid-read"),
51+
BrokenPipeError("pipe gone mid-read"),
52+
OSError("generic os error mid-read"),
53+
RuntimeError("Transport is closed mid-read"),
54+
],
55+
)
56+
async def test_reader_errors_are_wrapped(
57+
self,
58+
protocol: DqliteProtocol,
59+
mock_reader: AsyncMock,
60+
err: BaseException,
61+
) -> None:
62+
"""Transport errors on _read_data must also surface as
63+
DqliteConnectionError, matching the write-path behaviour.
64+
"""
65+
from dqliteclient.exceptions import DqliteConnectionError
66+
67+
mock_reader.read.side_effect = err
68+
with pytest.raises(DqliteConnectionError) as exc_info:
69+
await protocol._read_data()
70+
assert exc_info.value.__cause__ is err
71+
1972
async def test_handshake_success(
2073
self,
2174
protocol: DqliteProtocol,

0 commit comments

Comments
 (0)