Skip to content

Commit bfc0bb8

Browse files
Enforce thread safety and harden close() against races
Add thread-identity check like sqlite3 stdlib: store the creator thread ID and raise ProgrammingError if the connection or cursor is used from a different thread. This turns silent data corruption into a clear error message. Harden close(): set _closed=True immediately and return early on double-close, preventing races where close() tears down resources while operations are in flight. Apply to both sync and async paths. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9fe7530 commit bfc0bb8

File tree

5 files changed

+183
-113
lines changed

5 files changed

+183
-113
lines changed

src/dqlitedbapi/aio/connection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ async def connect(self) -> None:
6666

6767
async def close(self) -> None:
6868
"""Close the connection."""
69+
if self._closed:
70+
return
71+
self._closed = True
6972
if self._async_conn is not None:
7073
await self._async_conn.close()
7174
self._async_conn = None
72-
self._closed = True
7375

7476
async def commit(self) -> None:
7577
"""Commit any pending transaction."""

src/dqlitedbapi/connection.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from dqliteclient import DqliteConnection
99
from dqlitedbapi.cursor import Cursor
10-
from dqlitedbapi.exceptions import InterfaceError, OperationalError
10+
from dqlitedbapi.exceptions import InterfaceError, OperationalError, ProgrammingError
1111

1212

1313
class Connection:
@@ -37,6 +37,17 @@ def __init__(
3737
self._loop_lock = threading.Lock()
3838
self._op_lock = threading.Lock()
3939
self._connect_lock: asyncio.Lock | None = None
40+
self._creator_thread = threading.get_ident()
41+
42+
def _check_thread(self) -> None:
43+
"""Raise ProgrammingError if called from a different thread than the creator."""
44+
current = threading.get_ident()
45+
if current != self._creator_thread:
46+
raise ProgrammingError(
47+
f"Connection objects created in a thread can only be used in that "
48+
f"same thread. The object was created in thread id "
49+
f"{self._creator_thread} and this is thread id {current}."
50+
)
4051

4152
def _ensure_loop(self) -> asyncio.AbstractEventLoop:
4253
"""Ensure a dedicated event loop is running in a background thread.
@@ -102,6 +113,10 @@ async def _get_async_connection(self) -> DqliteConnection:
102113

103114
def close(self) -> None:
104115
"""Close the connection."""
116+
self._check_thread()
117+
if self._closed:
118+
return
119+
self._closed = True
105120
try:
106121
if self._async_conn is not None:
107122
with contextlib.suppress(Exception):
@@ -115,10 +130,10 @@ def close(self) -> None:
115130
self._loop.close()
116131
self._loop = None
117132
self._thread = None
118-
self._closed = True
119133

120134
def commit(self) -> None:
121135
"""Commit any pending transaction."""
136+
self._check_thread()
122137
if self._closed:
123138
raise InterfaceError("Connection is closed")
124139

@@ -132,6 +147,7 @@ async def _commit_async(self) -> None:
132147

133148
def rollback(self) -> None:
134149
"""Roll back any pending transaction."""
150+
self._check_thread()
135151
if self._closed:
136152
raise InterfaceError("Connection is closed")
137153

@@ -145,6 +161,7 @@ async def _rollback_async(self) -> None:
145161

146162
def cursor(self) -> Cursor:
147163
"""Return a new Cursor object."""
164+
self._check_thread()
148165
if self._closed:
149166
raise InterfaceError("Connection is closed")
150167
return Cursor(self)

src/dqlitedbapi/cursor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def _check_closed(self) -> None:
8282

8383
def execute(self, operation: str, parameters: Sequence[Any] | None = None) -> "Cursor":
8484
"""Execute a database operation (query or command)."""
85+
self._connection._check_thread()
8586
self._check_closed()
8687

8788
self._connection._run_sync(self._execute_async(operation, parameters))
@@ -123,6 +124,7 @@ async def _execute_async(self, operation: str, parameters: Sequence[Any] | None
123124

124125
def executemany(self, operation: str, seq_of_parameters: Sequence[Sequence[Any]]) -> "Cursor":
125126
"""Execute a database operation multiple times."""
127+
self._connection._check_thread()
126128
self._check_closed()
127129

128130
self._connection._run_sync(self._executemany_async(operation, seq_of_parameters))
@@ -145,6 +147,7 @@ def _check_result_set(self) -> None:
145147

146148
def fetchone(self) -> tuple[Any, ...] | None:
147149
"""Fetch the next row of a query result set."""
150+
self._connection._check_thread()
148151
self._check_closed()
149152
self._check_result_set()
150153

@@ -157,6 +160,7 @@ def fetchone(self) -> tuple[Any, ...] | None:
157160

158161
def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
159162
"""Fetch the next set of rows of a query result."""
163+
self._connection._check_thread()
160164
self._check_closed()
161165
self._check_result_set()
162166

@@ -174,6 +178,7 @@ def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
174178

175179
def fetchall(self) -> list[tuple[Any, ...]]:
176180
"""Fetch all remaining rows of a query result."""
181+
self._connection._check_thread()
177182
self._check_closed()
178183
self._check_result_set()
179184

@@ -183,6 +188,7 @@ def fetchall(self) -> list[tuple[Any, ...]]:
183188

184189
def close(self) -> None:
185190
"""Close the cursor."""
191+
self._connection._check_thread()
186192
self._closed = True
187193
self._rows = []
188194
self._description = None

tests/test_protocol_serialization.py

Lines changed: 19 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -74,122 +74,31 @@ async def mock_exec_sql(db_id: int, sql: str, params: object) -> tuple:
7474

7575

7676
class TestSyncProtocolSerialization:
77-
"""Test that concurrent sync operations are serialized."""
77+
"""Test that concurrent sync operations from wrong threads are rejected."""
7878

79-
def test_concurrent_run_sync_is_serialized(self) -> None:
80-
"""Two threads calling _run_sync must not overlap on the event loop.
79+
def test_cross_thread_execute_raises_programming_error(self) -> None:
80+
"""Threads sharing a connection must get ProgrammingError.
8181
82-
Without serialization, both threads submit coroutines concurrently
83-
to the same event loop, where they interleave at await points.
82+
The thread-identity check (like sqlite3) prevents cross-thread
83+
access before it reaches the protocol layer.
8484
"""
85-
conn = Connection("localhost:9001", timeout=5.0)
86-
87-
call_log: list[tuple[str, str]] = []
88-
log_lock = threading.Lock()
89-
90-
async def mock_query_sql(db_id: int, sql: str, params: object) -> tuple:
91-
with log_lock:
92-
call_log.append((sql, "start"))
93-
await asyncio.sleep(0.05)
94-
with log_lock:
95-
call_log.append((sql, "end"))
96-
return (["id"], [[1]])
97-
98-
async def mock_exec_sql(db_id: int, sql: str, params: object) -> tuple:
99-
with log_lock:
100-
call_log.append((sql, "start"))
101-
await asyncio.sleep(0.05)
102-
with log_lock:
103-
call_log.append((sql, "end"))
104-
return (0, 1)
105-
106-
with patch("dqlitedbapi.connection.DqliteConnection") as MockDqliteConn:
107-
mock_instance = AsyncMock()
108-
mock_instance.connect = AsyncMock()
109-
mock_instance._protocol = MagicMock()
110-
mock_instance._protocol.query_sql = mock_query_sql
111-
mock_instance._protocol.exec_sql = mock_exec_sql
112-
mock_instance._db_id = 0
113-
MockDqliteConn.return_value = mock_instance
114-
115-
cursor1 = Cursor(conn)
116-
cursor2 = Cursor(conn)
85+
from dqlitedbapi.exceptions import ProgrammingError
11786

118-
barrier = threading.Barrier(2)
119-
errors: list[Exception] = []
120-
121-
def thread_work(cursor: Cursor, sql: str) -> None:
122-
try:
123-
barrier.wait(timeout=5)
124-
cursor.execute(sql)
125-
except Exception as e:
126-
errors.append(e)
127-
128-
t1 = threading.Thread(target=thread_work, args=(cursor1, "SELECT 1"))
129-
t2 = threading.Thread(target=thread_work, args=(cursor2, "INSERT INTO t VALUES (1)"))
130-
t1.start()
131-
t2.start()
132-
t1.join(timeout=10)
133-
t2.join(timeout=10)
134-
135-
assert not errors, f"Threads raised errors: {errors}"
136-
137-
# Operations must be serialized: one completes before the other starts
138-
assert len(call_log) == 4
139-
assert call_log[0][1] == "start"
140-
assert call_log[1][1] == "end"
141-
assert call_log[2][1] == "start"
142-
assert call_log[3][1] == "end"
143-
144-
conn.close()
145-
146-
147-
class TestSyncConnectionLazyInitRace:
148-
"""Test that _get_async_connection doesn't create duplicate connections."""
149-
150-
def test_concurrent_first_use_creates_single_connection(self) -> None:
151-
"""Two threads using a connection for the first time must not
152-
create two underlying DqliteConnection instances."""
15387
conn = Connection("localhost:9001", timeout=5.0)
88+
cursor = Cursor(conn)
15489

155-
connect_count = 0
156-
count_lock = threading.Lock()
90+
errors: list[Exception] = []
15791

158-
async def slow_connect() -> None:
159-
nonlocal connect_count
160-
with count_lock:
161-
connect_count += 1
162-
await asyncio.sleep(0.05)
92+
def thread_work() -> None:
93+
try:
94+
cursor.execute("SELECT 1")
95+
except Exception as e:
96+
errors.append(e)
16397

164-
with patch("dqlitedbapi.connection.DqliteConnection") as MockDqliteConn:
165-
mock_instance = AsyncMock()
166-
mock_instance.connect = slow_connect
167-
mock_instance._protocol = MagicMock()
168-
mock_instance._protocol.query_sql = AsyncMock(return_value=(["id"], [[1]]))
169-
mock_instance._db_id = 0
170-
MockDqliteConn.return_value = mock_instance
98+
t = threading.Thread(target=thread_work)
99+
t.start()
100+
t.join(timeout=5)
171101

172-
barrier = threading.Barrier(2)
173-
errors: list[Exception] = []
174-
175-
def thread_work() -> None:
176-
try:
177-
barrier.wait(timeout=5)
178-
cursor = conn.cursor()
179-
cursor.execute("SELECT 1")
180-
except Exception as e:
181-
errors.append(e)
182-
183-
t1 = threading.Thread(target=thread_work)
184-
t2 = threading.Thread(target=thread_work)
185-
t1.start()
186-
t2.start()
187-
t1.join(timeout=10)
188-
t2.join(timeout=10)
189-
190-
assert not errors, f"Threads raised errors: {errors}"
191-
# Only one DqliteConnection should have been created
192-
assert MockDqliteConn.call_count == 1
193-
assert connect_count == 1
194-
195-
conn.close()
102+
assert len(errors) == 1
103+
assert isinstance(errors[0], ProgrammingError)
104+
conn.close()

0 commit comments

Comments
 (0)