Skip to content

Commit 9fe7530

Browse files
Add protocol operation serialization to prevent wire corruption
The dqlite wire protocol is single-request-at-a-time per connection. Without serialization, concurrent access (asyncio.gather or multiple threads) corrupts the TCP stream, causing response mix-ups or crashes. Add asyncio.Lock (_op_lock) to AsyncConnection, acquired around all protocol calls in AsyncCursor.execute(), commit(), and rollback(). Add threading.Lock (_op_lock) to Connection._run_sync() so only one thread uses the protocol at a time. Add asyncio.Lock (_connect_lock) to Connection._get_async_connection() to prevent duplicate connection creation on concurrent first use. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e01e0d1 commit 9fe7530

File tree

4 files changed

+247
-28
lines changed

4 files changed

+247
-28
lines changed

src/dqlitedbapi/aio/connection.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
self._async_conn: DqliteConnection | None = None
3232
self._closed = False
3333
self._connect_lock = asyncio.Lock()
34+
self._op_lock = asyncio.Lock()
3435

3536
async def _ensure_connection(self) -> DqliteConnection:
3637
"""Ensure the underlying connection is established."""
@@ -76,15 +77,17 @@ async def commit(self) -> None:
7677
raise InterfaceError("Connection is closed")
7778

7879
if self._async_conn is not None:
79-
await self._async_conn.execute("COMMIT")
80+
async with self._op_lock:
81+
await self._async_conn.execute("COMMIT")
8082

8183
async def rollback(self) -> None:
8284
"""Roll back any pending transaction."""
8385
if self._closed:
8486
raise InterfaceError("Connection is closed")
8587

8688
if self._async_conn is not None:
87-
await self._async_conn.execute("ROLLBACK")
89+
async with self._op_lock:
90+
await self._async_conn.execute("ROLLBACK")
8891

8992
def cursor(self) -> AsyncCursor:
9093
"""Return a new AsyncCursor object.

src/dqlitedbapi/aio/cursor.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,28 @@ async def execute(
9191
if conn._protocol is None or conn._db_id is None:
9292
raise InternalError("Connection protocol not initialized")
9393

94-
try:
95-
if is_query:
96-
columns, rows = await conn._protocol.query_sql(conn._db_id, operation, params)
97-
self._description = [(name, None, None, None, None, None, None) for name in columns]
98-
self._rows = [tuple(row) for row in rows]
99-
self._row_index = 0
100-
self._rowcount = len(rows)
101-
else:
102-
last_id, affected = await conn._protocol.exec_sql(conn._db_id, operation, params)
103-
self._lastrowid = last_id
104-
self._rowcount = affected
105-
self._description = None
106-
self._rows = []
107-
except (OperationalError, InterfaceError, InternalError):
108-
raise
109-
except Exception as e:
110-
raise OperationalError(str(e)) from e
94+
async with self._connection._op_lock:
95+
try:
96+
if is_query:
97+
columns, rows = await conn._protocol.query_sql(conn._db_id, operation, params)
98+
self._description = [
99+
(name, None, None, None, None, None, None) for name in columns
100+
]
101+
self._rows = [tuple(row) for row in rows]
102+
self._row_index = 0
103+
self._rowcount = len(rows)
104+
else:
105+
last_id, affected = await conn._protocol.exec_sql(
106+
conn._db_id, operation, params
107+
)
108+
self._lastrowid = last_id
109+
self._rowcount = affected
110+
self._description = None
111+
self._rows = []
112+
except (OperationalError, InterfaceError, InternalError):
113+
raise
114+
except Exception as e:
115+
raise OperationalError(str(e)) from e
111116

112117
return self
113118

src/dqlitedbapi/connection.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def __init__(
3535
self._loop: asyncio.AbstractEventLoop | None = None
3636
self._thread: threading.Thread | None = None
3737
self._loop_lock = threading.Lock()
38+
self._op_lock = threading.Lock()
39+
self._connect_lock: asyncio.Lock | None = None
3840

3941
def _ensure_loop(self) -> asyncio.AbstractEventLoop:
4042
"""Ensure a dedicated event loop is running in a background thread.
@@ -55,22 +57,36 @@ def _run_sync(self, coro: Any) -> Any:
5557
"""Run an async coroutine from sync code.
5658
5759
Submits the coroutine to the dedicated background event loop
58-
and blocks until the result is available.
60+
and blocks until the result is available. The operation lock
61+
ensures only one operation runs at a time, preventing wire
62+
protocol corruption from concurrent access.
5963
"""
60-
loop = self._ensure_loop()
61-
future = asyncio.run_coroutine_threadsafe(coro, loop)
62-
try:
63-
return future.result(timeout=self._timeout)
64-
except TimeoutError as e:
65-
future.cancel()
66-
raise OperationalError(f"Operation timed out after {self._timeout} seconds") from e
64+
with self._op_lock:
65+
loop = self._ensure_loop()
66+
future = asyncio.run_coroutine_threadsafe(coro, loop)
67+
try:
68+
# Future.result() provides a happens-before memory barrier,
69+
# ensuring all writes by the event loop thread are visible here.
70+
return future.result(timeout=self._timeout)
71+
except TimeoutError as e:
72+
future.cancel()
73+
raise OperationalError(f"Operation timed out after {self._timeout} seconds") from e
6774

6875
async def _get_async_connection(self) -> DqliteConnection:
6976
"""Get or create the underlying async connection."""
7077
if self._closed:
7178
raise InterfaceError("Connection is closed")
7279

73-
if self._async_conn is None:
80+
if self._async_conn is not None:
81+
return self._async_conn
82+
83+
if self._connect_lock is None:
84+
self._connect_lock = asyncio.Lock()
85+
86+
async with self._connect_lock:
87+
if self._async_conn is not None:
88+
return self._async_conn
89+
7490
conn = DqliteConnection(
7591
self._address,
7692
database=self._database,
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""Tests for protocol operation serialization.
2+
3+
The dqlite wire protocol is single-request-at-a-time per connection.
4+
Concurrent protocol operations must be serialized to prevent wire corruption.
5+
"""
6+
7+
import asyncio
8+
import threading
9+
from unittest.mock import AsyncMock, MagicMock, patch
10+
11+
import pytest
12+
13+
from dqlitedbapi.aio.connection import AsyncConnection
14+
from dqlitedbapi.connection import Connection
15+
from dqlitedbapi.cursor import Cursor
16+
17+
18+
class TestAsyncProtocolSerialization:
19+
"""Test that concurrent async operations are serialized."""
20+
21+
@pytest.mark.asyncio
22+
async def test_concurrent_execute_is_serialized(self) -> None:
23+
"""Two concurrent execute() calls must not overlap on the wire.
24+
25+
Without serialization, both coroutines would call query_sql/exec_sql
26+
concurrently, corrupting the TCP stream. With serialization, they
27+
must run one after the other.
28+
"""
29+
conn = AsyncConnection("localhost:9001")
30+
31+
call_log: list[tuple[str, str]] = [] # (operation, phase) pairs
32+
33+
async def mock_query_sql(db_id: int, sql: str, params: object) -> tuple:
34+
call_log.append((sql, "start"))
35+
await asyncio.sleep(0.05) # Simulate network I/O
36+
call_log.append((sql, "end"))
37+
return (["id"], [[1]])
38+
39+
async def mock_exec_sql(db_id: int, sql: str, params: object) -> tuple:
40+
call_log.append((sql, "start"))
41+
await asyncio.sleep(0.05)
42+
call_log.append((sql, "end"))
43+
return (0, 1)
44+
45+
with patch("dqlitedbapi.aio.connection.DqliteConnection") as MockDqliteConn:
46+
mock_instance = AsyncMock()
47+
mock_instance.connect = AsyncMock()
48+
mock_instance._protocol = MagicMock()
49+
mock_instance._protocol.query_sql = mock_query_sql
50+
mock_instance._protocol.exec_sql = mock_exec_sql
51+
mock_instance._db_id = 0
52+
MockDqliteConn.return_value = mock_instance
53+
54+
await conn.connect()
55+
56+
cursor1 = conn.cursor()
57+
cursor2 = conn.cursor()
58+
59+
# Run two operations concurrently
60+
await asyncio.gather(
61+
cursor1.execute("SELECT 1"),
62+
cursor2.execute("INSERT INTO t VALUES (1)"),
63+
)
64+
65+
# With proper serialization, operations must not interleave.
66+
# The call_log should show: start/end of one op, then start/end of the other.
67+
# NOT: start/start/end/end (interleaved).
68+
assert len(call_log) == 4
69+
# First operation must complete before second starts
70+
assert call_log[0][1] == "start"
71+
assert call_log[1][1] == "end"
72+
assert call_log[2][1] == "start"
73+
assert call_log[3][1] == "end"
74+
75+
76+
class TestSyncProtocolSerialization:
77+
"""Test that concurrent sync operations are serialized."""
78+
79+
def test_concurrent_run_sync_is_serialized(self) -> None:
80+
"""Two threads calling _run_sync must not overlap on the event loop.
81+
82+
Without serialization, both threads submit coroutines concurrently
83+
to the same event loop, where they interleave at await points.
84+
"""
85+
conn = Connection("localhost:9001", timeout=5.0)
86+
87+
call_log: list[tuple[str, str]] = []
88+
log_lock = threading.Lock()
89+
90+
async def mock_query_sql(db_id: int, sql: str, params: object) -> tuple:
91+
with log_lock:
92+
call_log.append((sql, "start"))
93+
await asyncio.sleep(0.05)
94+
with log_lock:
95+
call_log.append((sql, "end"))
96+
return (["id"], [[1]])
97+
98+
async def mock_exec_sql(db_id: int, sql: str, params: object) -> tuple:
99+
with log_lock:
100+
call_log.append((sql, "start"))
101+
await asyncio.sleep(0.05)
102+
with log_lock:
103+
call_log.append((sql, "end"))
104+
return (0, 1)
105+
106+
with patch("dqlitedbapi.connection.DqliteConnection") as MockDqliteConn:
107+
mock_instance = AsyncMock()
108+
mock_instance.connect = AsyncMock()
109+
mock_instance._protocol = MagicMock()
110+
mock_instance._protocol.query_sql = mock_query_sql
111+
mock_instance._protocol.exec_sql = mock_exec_sql
112+
mock_instance._db_id = 0
113+
MockDqliteConn.return_value = mock_instance
114+
115+
cursor1 = Cursor(conn)
116+
cursor2 = Cursor(conn)
117+
118+
barrier = threading.Barrier(2)
119+
errors: list[Exception] = []
120+
121+
def thread_work(cursor: Cursor, sql: str) -> None:
122+
try:
123+
barrier.wait(timeout=5)
124+
cursor.execute(sql)
125+
except Exception as e:
126+
errors.append(e)
127+
128+
t1 = threading.Thread(target=thread_work, args=(cursor1, "SELECT 1"))
129+
t2 = threading.Thread(target=thread_work, args=(cursor2, "INSERT INTO t VALUES (1)"))
130+
t1.start()
131+
t2.start()
132+
t1.join(timeout=10)
133+
t2.join(timeout=10)
134+
135+
assert not errors, f"Threads raised errors: {errors}"
136+
137+
# Operations must be serialized: one completes before the other starts
138+
assert len(call_log) == 4
139+
assert call_log[0][1] == "start"
140+
assert call_log[1][1] == "end"
141+
assert call_log[2][1] == "start"
142+
assert call_log[3][1] == "end"
143+
144+
conn.close()
145+
146+
147+
class TestSyncConnectionLazyInitRace:
148+
"""Test that _get_async_connection doesn't create duplicate connections."""
149+
150+
def test_concurrent_first_use_creates_single_connection(self) -> None:
151+
"""Two threads using a connection for the first time must not
152+
create two underlying DqliteConnection instances."""
153+
conn = Connection("localhost:9001", timeout=5.0)
154+
155+
connect_count = 0
156+
count_lock = threading.Lock()
157+
158+
async def slow_connect() -> None:
159+
nonlocal connect_count
160+
with count_lock:
161+
connect_count += 1
162+
await asyncio.sleep(0.05)
163+
164+
with patch("dqlitedbapi.connection.DqliteConnection") as MockDqliteConn:
165+
mock_instance = AsyncMock()
166+
mock_instance.connect = slow_connect
167+
mock_instance._protocol = MagicMock()
168+
mock_instance._protocol.query_sql = AsyncMock(return_value=(["id"], [[1]]))
169+
mock_instance._db_id = 0
170+
MockDqliteConn.return_value = mock_instance
171+
172+
barrier = threading.Barrier(2)
173+
errors: list[Exception] = []
174+
175+
def thread_work() -> None:
176+
try:
177+
barrier.wait(timeout=5)
178+
cursor = conn.cursor()
179+
cursor.execute("SELECT 1")
180+
except Exception as e:
181+
errors.append(e)
182+
183+
t1 = threading.Thread(target=thread_work)
184+
t2 = threading.Thread(target=thread_work)
185+
t1.start()
186+
t2.start()
187+
t1.join(timeout=10)
188+
t2.join(timeout=10)
189+
190+
assert not errors, f"Threads raised errors: {errors}"
191+
# Only one DqliteConnection should have been created
192+
assert MockDqliteConn.call_count == 1
193+
assert connect_count == 1
194+
195+
conn.close()

0 commit comments

Comments
 (0)