Skip to content

Commit 0176fcf

Browse files
fix: add transaction task-ownership tracking to prevent cross-task injection
When a connection is in a transaction, another asyncio task could call execute() between the transaction owner's operations (when _in_use is temporarily False), silently injecting operations into the transaction. Track the owning task via asyncio.current_task() in the transaction() context manager. _check_in_use() now rejects operations from tasks that don't own the active transaction. Fixes #100 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5e57d96 commit 0176fcf

2 files changed

Lines changed: 75 additions & 0 deletions

File tree

src/dqliteclient/connection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
self._in_transaction = False
8282
self._in_use = False
8383
self._bound_loop: asyncio.AbstractEventLoop | None = None
84+
self._tx_owner: asyncio.Task[Any] | None = None
8485

8586
@property
8687
def address(self) -> str:
@@ -172,6 +173,14 @@ def _check_in_use(self) -> None:
172173
"connection. DqliteConnection does not support concurrent coroutine "
173174
"access. Use a ConnectionPool to manage multiple concurrent operations."
174175
)
176+
if self._in_transaction and self._tx_owner is not None:
177+
current = asyncio.current_task()
178+
if current is not self._tx_owner:
179+
raise InterfaceError(
180+
"Cannot perform operation: connection is in a transaction owned "
181+
"by another task. Each task should use its own connection from "
182+
"the pool."
183+
)
175184

176185
def _invalidate(self) -> None:
177186
"""Mark the connection as broken after an unrecoverable error."""
@@ -286,9 +295,11 @@ async def transaction(self) -> AsyncIterator[None]:
286295
)
287296

288297
self._in_transaction = True
298+
self._tx_owner = asyncio.current_task()
289299
try:
290300
await self.execute("BEGIN")
291301
except BaseException:
302+
self._tx_owner = None
292303
self._in_transaction = False
293304
raise
294305

@@ -301,4 +312,5 @@ async def transaction(self) -> AsyncIterator[None]:
301312
await self.execute("ROLLBACK")
302313
raise
303314
finally:
315+
self._tx_owner = None
304316
self._in_transaction = False

tests/test_connection.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,3 +899,66 @@ async def use_conn():
899899
f"Expected InterfaceError, got {type(error_from_thread).__name__}: {error_from_thread}"
900900
)
901901
assert "event loop" in str(error_from_thread).lower()
902+
903+
async def test_other_task_rejected_during_transaction(self) -> None:
904+
"""Another task calling execute() during an active transaction must be rejected."""
905+
import asyncio
906+
907+
from dqliteclient.exceptions import InterfaceError
908+
909+
conn = DqliteConnection("localhost:9001")
910+
911+
mock_reader = AsyncMock()
912+
mock_writer = MagicMock()
913+
mock_writer.drain = AsyncMock()
914+
mock_writer.close = MagicMock()
915+
mock_writer.wait_closed = AsyncMock()
916+
917+
from dqlitewire.messages import DbResponse, WelcomeResponse
918+
919+
responses = [
920+
WelcomeResponse(heartbeat_timeout=15000).encode(),
921+
DbResponse(db_id=1).encode(),
922+
]
923+
mock_reader.read.side_effect = responses
924+
925+
with patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)):
926+
await conn.connect()
927+
928+
# Mock execute for the transaction owner (task A) — needs to work
929+
a_inside_tx = asyncio.Event()
930+
931+
async def mock_execute_a(sql: str, params=None):
932+
if sql not in ("BEGIN", "COMMIT", "ROLLBACK"):
933+
a_inside_tx.set()
934+
await asyncio.sleep(0) # yield to let task B run
935+
return (0, 0)
936+
937+
conn.execute = mock_execute_a # type: ignore[assignment]
938+
939+
errors: list[Exception] = []
940+
941+
async def task_a():
942+
async with conn.transaction():
943+
await conn.execute("INSERT INTO t VALUES (1)")
944+
await asyncio.sleep(0.1)
945+
946+
async def task_b():
947+
await a_inside_tx.wait()
948+
try:
949+
# Use _check_in_use directly — this is what real execute() calls
950+
conn._check_in_use()
951+
except InterfaceError as e:
952+
errors.append(e)
953+
954+
t_a = asyncio.create_task(task_a())
955+
t_b = asyncio.create_task(task_b())
956+
957+
await asyncio.gather(t_a, t_b, return_exceptions=True)
958+
959+
assert len(errors) == 1, (
960+
"Task B should have been rejected when trying to use a connection "
961+
"that is in a transaction owned by task A"
962+
)
963+
assert isinstance(errors[0], InterfaceError)
964+
assert "transaction" in str(errors[0]).lower()

0 commit comments

Comments
 (0)