Skip to content

Commit 860fa24

Browse files
fix: add re-entrance guard to prevent concurrent coroutine access
DqliteConnection now detects when two coroutines attempt to use the same connection concurrently and raises InterfaceError immediately, matching the behavior of asyncpg and other mature async DB clients. The dqlite wire protocol is strictly request-response with no pipelining. Without this guard, concurrent coroutines would interleave their sends and reads on the shared StreamWriter/StreamReader, causing response mix-ups, ProtocolErrors, or silent data corruption. Each public method (execute, fetch, fetchall, fetchval) sets an _in_use flag before the protocol operation and clears it after. If a second coroutine tries to use the connection while _in_use is True, InterfaceError is raised with a clear message directing users to use ConnectionPool for concurrent operations. Also adds InterfaceError to the exceptions module and public API. Closes #077 Also resolves #079 (_in_transaction flag race) and #084 (concurrent connect TOCTOU) since the guard prevents concurrent access entirely. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 84aa17f commit 860fa24

4 files changed

Lines changed: 96 additions & 1 deletion

File tree

src/dqliteclient/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
ClusterError,
77
DqliteConnectionError,
88
DqliteError,
9+
InterfaceError,
910
OperationalError,
1011
ProtocolError,
1112
)
@@ -23,6 +24,7 @@
2324
"MemoryNodeStore",
2425
"DqliteError",
2526
"DqliteConnectionError",
27+
"InterfaceError",
2628
"ProtocolError",
2729
"ClusterError",
2830
"OperationalError",

src/dqliteclient/connection.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from contextlib import asynccontextmanager
77
from typing import Any
88

9-
from dqliteclient.exceptions import DqliteConnectionError, OperationalError, ProtocolError
9+
from dqliteclient.exceptions import (
10+
DqliteConnectionError,
11+
InterfaceError,
12+
OperationalError,
13+
ProtocolError,
14+
)
1015
from dqliteclient.protocol import DqliteProtocol
1116

1217
# dqlite error codes that indicate a leader change (SQLite extended error codes)
@@ -68,6 +73,7 @@ def __init__(
6873
self._protocol: DqliteProtocol | None = None
6974
self._db_id: int | None = None
7075
self._in_transaction = False
76+
self._in_use = False
7177

7278
@property
7379
def address(self) -> str:
@@ -127,6 +133,15 @@ def _ensure_connected(self) -> tuple[DqliteProtocol, int]:
127133
raise DqliteConnectionError("Not connected")
128134
return self._protocol, self._db_id
129135

136+
def _check_in_use(self) -> None:
137+
"""Raise if another coroutine is using this connection."""
138+
if self._in_use:
139+
raise InterfaceError(
140+
"Cannot perform operation: another operation is in progress on this "
141+
"connection. DqliteConnection does not support concurrent coroutine "
142+
"access. Use a ConnectionPool to manage multiple concurrent operations."
143+
)
144+
130145
def _invalidate(self) -> None:
131146
"""Mark the connection as broken after an unrecoverable error."""
132147
if self._protocol is not None:
@@ -141,7 +156,9 @@ async def execute(self, sql: str, params: Sequence[Any] | None = None) -> tuple[
141156
142157
Returns (last_insert_id, rows_affected).
143158
"""
159+
self._check_in_use()
144160
protocol, db_id = self._ensure_connected()
161+
self._in_use = True
145162
try:
146163
return await protocol.exec_sql(db_id, sql, params)
147164
except (DqliteConnectionError, ProtocolError):
@@ -154,10 +171,14 @@ async def execute(self, sql: str, params: Sequence[Any] | None = None) -> tuple[
154171
except BaseException:
155172
self._invalidate()
156173
raise
174+
finally:
175+
self._in_use = False
157176

158177
async def fetch(self, sql: str, params: Sequence[Any] | None = None) -> list[dict[str, Any]]:
159178
"""Execute a query and return results as list of dicts."""
179+
self._check_in_use()
160180
protocol, db_id = self._ensure_connected()
181+
self._in_use = True
161182
try:
162183
columns, rows = await protocol.query_sql(db_id, sql, params)
163184
except (DqliteConnectionError, ProtocolError):
@@ -170,11 +191,15 @@ async def fetch(self, sql: str, params: Sequence[Any] | None = None) -> list[dic
170191
except BaseException:
171192
self._invalidate()
172193
raise
194+
finally:
195+
self._in_use = False
173196
return [dict(zip(columns, row, strict=True)) for row in rows]
174197

175198
async def fetchall(self, sql: str, params: Sequence[Any] | None = None) -> list[list[Any]]:
176199
"""Execute a query and return results as list of lists."""
200+
self._check_in_use()
177201
protocol, db_id = self._ensure_connected()
202+
self._in_use = True
178203
try:
179204
_, rows = await protocol.query_sql(db_id, sql, params)
180205
except (DqliteConnectionError, ProtocolError):
@@ -187,6 +212,8 @@ async def fetchall(self, sql: str, params: Sequence[Any] | None = None) -> list[
187212
except BaseException:
188213
self._invalidate()
189214
raise
215+
finally:
216+
self._in_use = False
190217
return rows
191218

192219
async def fetchone(
@@ -198,7 +225,9 @@ async def fetchone(
198225

199226
async def fetchval(self, sql: str, params: Sequence[Any] | None = None) -> Any:
200227
"""Execute a query and return the first column of the first row."""
228+
self._check_in_use()
201229
protocol, db_id = self._ensure_connected()
230+
self._in_use = True
202231
try:
203232
_, rows = await protocol.query_sql(db_id, sql, params)
204233
except (DqliteConnectionError, ProtocolError):
@@ -211,6 +240,8 @@ async def fetchval(self, sql: str, params: Sequence[Any] | None = None) -> Any:
211240
except BaseException:
212241
self._invalidate()
213242
raise
243+
finally:
244+
self._in_use = False
214245
if rows and rows[0]:
215246
return rows[0][0]
216247
return None

src/dqliteclient/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ class ProtocolError(DqliteError):
1313
"""Protocol-level error."""
1414

1515

16+
class InterfaceError(DqliteError):
17+
"""Misuse of the client interface (e.g. concurrent access on a connection)."""
18+
19+
1620
class ClusterError(DqliteError):
1721
"""Cluster-related error (leader not found, etc)."""
1822

tests/test_connection.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,3 +617,61 @@ async def test_not_leader_error_invalidates_connection(self) -> None:
617617

618618
# Connection should be invalidated after a leader error
619619
assert not conn.is_connected
620+
621+
async def test_concurrent_coroutines_raises_interface_error(self) -> None:
622+
"""Two coroutines using the same connection must raise InterfaceError."""
623+
import asyncio
624+
625+
from dqliteclient.exceptions import InterfaceError
626+
627+
conn = DqliteConnection("localhost:9001")
628+
629+
mock_reader = AsyncMock()
630+
mock_writer = MagicMock()
631+
mock_writer.drain = AsyncMock()
632+
mock_writer.close = MagicMock()
633+
mock_writer.wait_closed = AsyncMock()
634+
635+
from dqlitewire.messages import DbResponse, WelcomeResponse
636+
637+
responses = [
638+
WelcomeResponse(heartbeat_timeout=15000).encode(),
639+
DbResponse(db_id=1).encode(),
640+
]
641+
mock_reader.read.side_effect = responses
642+
643+
with patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)):
644+
await conn.connect()
645+
646+
# Make execute hang so two coroutines overlap
647+
first_entered = asyncio.Event()
648+
649+
async def slow_exec_sql(db_id, sql, params=None):
650+
first_entered.set()
651+
await asyncio.sleep(10)
652+
return (0, 1)
653+
654+
conn._protocol.exec_sql = AsyncMock(side_effect=slow_exec_sql) # type: ignore[union-attr]
655+
656+
errors: list[Exception] = []
657+
658+
async def first_execute():
659+
await conn.execute("INSERT INTO t VALUES (1)")
660+
661+
async def second_execute():
662+
await first_entered.wait()
663+
try:
664+
await conn.execute("INSERT INTO t VALUES (2)")
665+
except InterfaceError as e:
666+
errors.append(e)
667+
668+
task1 = asyncio.create_task(first_execute())
669+
task2 = asyncio.create_task(second_execute())
670+
671+
await task2 # second should raise InterfaceError
672+
task1.cancel()
673+
with pytest.raises(asyncio.CancelledError):
674+
await task1
675+
676+
assert len(errors) == 1
677+
assert "another operation is in progress" in str(errors[0])

0 commit comments

Comments
 (0)