Skip to content

Commit 29ec507

Browse files
fix: invalidate connection state after protocol errors
After a ConnectionError or ProtocolError, the connection's _protocol and _db_id are now cleared so is_connected returns False. Previously, broken connections still reported is_connected=True, causing the pool to hand out dead connections. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2e80611 commit 29ec507

2 files changed

Lines changed: 60 additions & 5 deletions

File tree

src/dqliteclient/connection.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from contextlib import asynccontextmanager
66
from typing import Any
77

8-
from dqliteclient.exceptions import ConnectionError
8+
from dqliteclient.exceptions import ConnectionError, ProtocolError
99
from dqliteclient.protocol import DqliteProtocol
1010

1111

@@ -92,24 +92,41 @@ def _ensure_connected(self) -> tuple[DqliteProtocol, int]:
9292
raise ConnectionError("Not connected")
9393
return self._protocol, self._db_id
9494

95+
def _invalidate(self) -> None:
96+
"""Mark the connection as broken after an unrecoverable error."""
97+
self._protocol = None
98+
self._db_id = None
99+
95100
async def execute(self, sql: str, params: list[Any] | None = None) -> tuple[int, int]:
96101
"""Execute a SQL statement.
97102
98103
Returns (last_insert_id, rows_affected).
99104
"""
100105
protocol, db_id = self._ensure_connected()
101-
return await protocol.exec_sql(db_id, sql, params)
106+
try:
107+
return await protocol.exec_sql(db_id, sql, params)
108+
except (ConnectionError, ProtocolError):
109+
self._invalidate()
110+
raise
102111

103112
async def fetch(self, sql: str, params: list[Any] | None = None) -> list[dict[str, Any]]:
104113
"""Execute a query and return results as list of dicts."""
105114
protocol, db_id = self._ensure_connected()
106-
columns, rows = await protocol.query_sql(db_id, sql, params)
115+
try:
116+
columns, rows = await protocol.query_sql(db_id, sql, params)
117+
except (ConnectionError, ProtocolError):
118+
self._invalidate()
119+
raise
107120
return [dict(zip(columns, row, strict=True)) for row in rows]
108121

109122
async def fetchall(self, sql: str, params: list[Any] | None = None) -> list[list[Any]]:
110123
"""Execute a query and return results as list of lists."""
111124
protocol, db_id = self._ensure_connected()
112-
_, rows = await protocol.query_sql(db_id, sql, params)
125+
try:
126+
_, rows = await protocol.query_sql(db_id, sql, params)
127+
except (ConnectionError, ProtocolError):
128+
self._invalidate()
129+
raise
113130
return rows
114131

115132
async def fetchone(self, sql: str, params: list[Any] | None = None) -> dict[str, Any] | None:
@@ -120,7 +137,11 @@ async def fetchone(self, sql: str, params: list[Any] | None = None) -> dict[str,
120137
async def fetchval(self, sql: str, params: list[Any] | None = None) -> Any:
121138
"""Execute a query and return the first column of the first row."""
122139
protocol, db_id = self._ensure_connected()
123-
_, rows = await protocol.query_sql(db_id, sql, params)
140+
try:
141+
_, rows = await protocol.query_sql(db_id, sql, params)
142+
except (ConnectionError, ProtocolError):
143+
self._invalidate()
144+
raise
124145
if rows and rows[0]:
125146
return rows[0][0]
126147
return None

tests/test_connection.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,37 @@ async def mock_execute(sql: str, params=None):
142142
assert "ROLLBACK" in call_log
143143
# _in_transaction was cleaned up
144144
assert not conn._in_transaction
145+
146+
async def test_connection_invalidated_after_protocol_error(self) -> None:
147+
"""After a connection error, is_connected should return False."""
148+
conn = DqliteConnection("localhost:9001")
149+
150+
mock_reader = AsyncMock()
151+
mock_writer = MagicMock()
152+
mock_writer.drain = AsyncMock()
153+
mock_writer.close = MagicMock()
154+
mock_writer.wait_closed = AsyncMock()
155+
156+
from dqlitewire.messages import DbResponse, WelcomeResponse
157+
158+
responses = [
159+
WelcomeResponse(heartbeat_timeout=15000).encode(),
160+
DbResponse(db_id=1).encode(),
161+
]
162+
mock_reader.read.side_effect = responses
163+
164+
with patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)):
165+
await conn.connect()
166+
167+
assert conn.is_connected
168+
169+
# Now make the reader return empty (connection closed)
170+
mock_reader.read.side_effect = [b""]
171+
172+
from dqliteclient.exceptions import ConnectionError as DqliteConnectionError
173+
174+
with pytest.raises(DqliteConnectionError):
175+
await conn.execute("SELECT 1")
176+
177+
# Connection should be invalidated
178+
assert not conn.is_connected

0 commit comments

Comments
 (0)