Skip to content

Commit 84aa17f

Browse files
fix: invalidate connection on CancelledError to prevent decoder corruption
When a protocol read is cancelled mid-flight, the MessageDecoder may contain partial data from an incomplete message, or the TCP buffer may contain an orphaned response. Either way, the connection is irrecoverably corrupted — subsequent reads would decode garbage or return wrong responses. Add `except BaseException: self._invalidate(); raise` to execute(), fetch(), fetchall(), and fetchval() so that CancelledError (and other BaseException subclasses) properly invalidate the connection rather than leaving it in a poisoned state. Closes #085 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent afe4a7f commit 84aa17f

2 files changed

Lines changed: 62 additions & 0 deletions

File tree

src/dqliteclient/connection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ async def execute(self, sql: str, params: Sequence[Any] | None = None) -> tuple[
151151
if e.code in _LEADER_ERROR_CODES:
152152
self._invalidate()
153153
raise
154+
except BaseException:
155+
self._invalidate()
156+
raise
154157

155158
async def fetch(self, sql: str, params: Sequence[Any] | None = None) -> list[dict[str, Any]]:
156159
"""Execute a query and return results as list of dicts."""
@@ -164,6 +167,9 @@ async def fetch(self, sql: str, params: Sequence[Any] | None = None) -> list[dic
164167
if e.code in _LEADER_ERROR_CODES:
165168
self._invalidate()
166169
raise
170+
except BaseException:
171+
self._invalidate()
172+
raise
167173
return [dict(zip(columns, row, strict=True)) for row in rows]
168174

169175
async def fetchall(self, sql: str, params: Sequence[Any] | None = None) -> list[list[Any]]:
@@ -178,6 +184,9 @@ async def fetchall(self, sql: str, params: Sequence[Any] | None = None) -> list[
178184
if e.code in _LEADER_ERROR_CODES:
179185
self._invalidate()
180186
raise
187+
except BaseException:
188+
self._invalidate()
189+
raise
181190
return rows
182191

183192
async def fetchone(
@@ -199,6 +208,9 @@ async def fetchval(self, sql: str, params: Sequence[Any] | None = None) -> Any:
199208
if e.code in _LEADER_ERROR_CODES:
200209
self._invalidate()
201210
raise
211+
except BaseException:
212+
self._invalidate()
213+
raise
202214
if rows and rows[0]:
203215
return rows[0][0]
204216
return None

tests/test_connection.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,56 @@ async def mock_execute(sql: str, params=None):
226226
# _in_transaction was cleaned up
227227
assert not conn._in_transaction
228228

229+
async def test_cancellation_invalidates_connection(self) -> None:
230+
"""CancelledError during a query must invalidate the connection."""
231+
import asyncio
232+
233+
conn = DqliteConnection("localhost:9001")
234+
235+
mock_reader = AsyncMock()
236+
mock_writer = MagicMock()
237+
mock_writer.drain = AsyncMock()
238+
mock_writer.close = MagicMock()
239+
mock_writer.wait_closed = AsyncMock()
240+
241+
from dqlitewire.messages import DbResponse, WelcomeResponse
242+
243+
responses = [
244+
WelcomeResponse(heartbeat_timeout=15000).encode(),
245+
DbResponse(db_id=1).encode(),
246+
]
247+
mock_reader.read.side_effect = responses
248+
249+
with patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)):
250+
await conn.connect()
251+
252+
assert conn.is_connected
253+
254+
# Make the reader hang forever (will be cancelled)
255+
read_entered = asyncio.Event()
256+
257+
async def hanging_read(*args, **kwargs):
258+
read_entered.set()
259+
await asyncio.sleep(100)
260+
261+
mock_reader.read.side_effect = hanging_read
262+
263+
async def do_execute():
264+
await conn.execute("INSERT INTO t VALUES (1)")
265+
266+
task = asyncio.create_task(do_execute())
267+
await read_entered.wait()
268+
task.cancel()
269+
270+
with pytest.raises(asyncio.CancelledError):
271+
await task
272+
273+
# Connection must be invalidated — the decoder may have partial data
274+
assert not conn.is_connected, (
275+
"Connection should be invalidated after CancelledError to prevent "
276+
"decoder corruption from partial reads"
277+
)
278+
229279
async def test_connection_invalidated_after_protocol_error(self) -> None:
230280
"""After a connection error, is_connected should return False."""
231281
conn = DqliteConnection("localhost:9001")

0 commit comments

Comments
 (0)