Skip to content

Commit 4b2ce9b

Browse files
Move cross-thread state access to event loop thread
- commit()/rollback(): Move _async_conn None-check into _commit_async/ _rollback_async which run on the event loop thread. Remove _get_async_connection() calls that would create spurious connections. - close(): Move _async_conn handling into _close_async() coroutine. Acquire _loop_lock for loop teardown to prevent races with _ensure_loop(). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4bc128c commit 4b2ce9b

File tree

2 files changed

+96
-20
lines changed

2 files changed

+96
-20
lines changed

src/dqlitedbapi/connection.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -118,46 +118,50 @@ def close(self) -> None:
118118
return
119119
self._closed = True
120120
try:
121-
if self._async_conn is not None:
121+
if self._loop is not None and not self._loop.is_closed():
122122
with contextlib.suppress(Exception):
123-
self._run_sync(self._async_conn.close())
124-
self._async_conn = None
123+
self._run_sync(self._close_async())
125124
finally:
126-
if self._loop is not None and not self._loop.is_closed():
127-
self._loop.call_soon_threadsafe(self._loop.stop)
128-
if self._thread is not None:
129-
self._thread.join(timeout=5)
130-
self._loop.close()
131-
self._loop = None
132-
self._thread = None
125+
with self._loop_lock:
126+
if self._loop is not None and not self._loop.is_closed():
127+
self._loop.call_soon_threadsafe(self._loop.stop)
128+
if self._thread is not None:
129+
self._thread.join(timeout=5)
130+
self._loop.close()
131+
self._loop = None
132+
self._thread = None
133+
134+
async def _close_async(self) -> None:
135+
"""Async implementation of close -- runs on event loop thread."""
136+
if self._async_conn is not None:
137+
try:
138+
await self._async_conn.close()
139+
finally:
140+
self._async_conn = None
133141

134142
def commit(self) -> None:
135143
"""Commit any pending transaction."""
136144
self._check_thread()
137145
if self._closed:
138146
raise InterfaceError("Connection is closed")
139-
140-
if self._async_conn is not None:
141-
self._run_sync(self._commit_async())
147+
self._run_sync(self._commit_async())
142148

143149
async def _commit_async(self) -> None:
144150
"""Async implementation of commit."""
145-
conn = await self._get_async_connection()
146-
await conn.execute("COMMIT")
151+
if self._async_conn is not None:
152+
await self._async_conn.execute("COMMIT")
147153

148154
def rollback(self) -> None:
149155
"""Roll back any pending transaction."""
150156
self._check_thread()
151157
if self._closed:
152158
raise InterfaceError("Connection is closed")
153-
154-
if self._async_conn is not None:
155-
self._run_sync(self._rollback_async())
159+
self._run_sync(self._rollback_async())
156160

157161
async def _rollback_async(self) -> None:
158162
"""Async implementation of rollback."""
159-
conn = await self._get_async_connection()
160-
await conn.execute("ROLLBACK")
163+
if self._async_conn is not None:
164+
await self._async_conn.execute("ROLLBACK")
161165

162166
def cursor(self) -> Cursor:
163167
"""Return a new Cursor object."""
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""Tests that commit/rollback don't create spurious connections.
2+
3+
If commit() or rollback() is called on a connection that was never used
4+
(no execute() was called), it should be a no-op. It should NOT create a
5+
new TCP connection just to send COMMIT/ROLLBACK.
6+
"""
7+
8+
from unittest.mock import AsyncMock, patch
9+
10+
from dqlitedbapi.connection import Connection
11+
12+
13+
class TestCommitNoSpuriousConnect:
14+
def test_commit_on_unused_connection_is_noop(self) -> None:
15+
"""commit() should not create a connection if none exists."""
16+
conn = Connection("localhost:9001", timeout=2.0)
17+
18+
with patch.object(conn, "_get_async_connection") as mock_get:
19+
mock_get.return_value = AsyncMock()
20+
conn.commit()
21+
mock_get.assert_not_called()
22+
23+
conn.close()
24+
25+
def test_rollback_on_unused_connection_is_noop(self) -> None:
26+
"""rollback() should not create a connection if none exists."""
27+
conn = Connection("localhost:9001", timeout=2.0)
28+
29+
with patch.object(conn, "_get_async_connection") as mock_get:
30+
mock_get.return_value = AsyncMock()
31+
conn.rollback()
32+
mock_get.assert_not_called()
33+
34+
conn.close()
35+
36+
37+
class TestCommitRollbackAsyncNoSpuriousConnect:
38+
def test_commit_async_does_not_call_get_async_connection(self) -> None:
39+
"""_commit_async should check _async_conn directly, not call _get_async_connection."""
40+
import ast
41+
import inspect
42+
import textwrap
43+
44+
source = textwrap.dedent(inspect.getsource(Connection._commit_async))
45+
tree = ast.parse(source)
46+
47+
for node in ast.walk(tree):
48+
if isinstance(node, ast.Call):
49+
func = node.func
50+
if isinstance(func, ast.Attribute) and func.attr == "_get_async_connection":
51+
raise AssertionError(
52+
"_commit_async calls _get_async_connection which creates "
53+
"new connections. It should check _async_conn directly."
54+
)
55+
56+
def test_rollback_async_does_not_call_get_async_connection(self) -> None:
57+
"""_rollback_async should check _async_conn directly, not call _get_async_connection."""
58+
import ast
59+
import inspect
60+
import textwrap
61+
62+
source = textwrap.dedent(inspect.getsource(Connection._rollback_async))
63+
tree = ast.parse(source)
64+
65+
for node in ast.walk(tree):
66+
if isinstance(node, ast.Call):
67+
func = node.func
68+
if isinstance(func, ast.Attribute) and func.attr == "_get_async_connection":
69+
raise AssertionError(
70+
"_rollback_async calls _get_async_connection which creates "
71+
"new connections. It should check _async_conn directly."
72+
)

0 commit comments

Comments
 (0)