Skip to content

Commit 8fbec2d

Browse files
fix: set _in_transaction before await to produce correct error on concurrent entry
The _in_transaction flag was set AFTER `await self.execute("BEGIN")`, creating a TOCTOU window where a second coroutine could pass the _in_transaction check during the await and enter the transaction block. The second coroutine would then get a confusing InterfaceError about concurrent access instead of the correct OperationalError about nested transactions. Move the flag set before the await, with a try/except to reset it if BEGIN fails. Fixes #096 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a278765 commit 8fbec2d

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

src/dqliteclient/connection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,12 @@ async def transaction(self) -> AsyncIterator[None]:
258258
0, "Nested transactions are not supported; use SAVEPOINT directly"
259259
)
260260

261-
await self.execute("BEGIN")
262261
self._in_transaction = True
262+
try:
263+
await self.execute("BEGIN")
264+
except BaseException:
265+
self._in_transaction = False
266+
raise
263267

264268
try:
265269
yield

tests/test_connection.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,3 +775,70 @@ async def do_execute():
775775
task.cancel()
776776
with pytest.raises(asyncio.CancelledError):
777777
await task
778+
779+
async def test_concurrent_transaction_raises_operational_error(self) -> None:
780+
"""Second concurrent transaction() must raise OperationalError, not InterfaceError."""
781+
import asyncio
782+
783+
from dqliteclient.exceptions import InterfaceError, OperationalError
784+
785+
conn = DqliteConnection("localhost:9001")
786+
787+
mock_reader = AsyncMock()
788+
mock_writer = MagicMock()
789+
mock_writer.drain = AsyncMock()
790+
mock_writer.close = MagicMock()
791+
mock_writer.wait_closed = AsyncMock()
792+
793+
from dqlitewire.messages import DbResponse, WelcomeResponse
794+
795+
responses = [
796+
WelcomeResponse(heartbeat_timeout=15000).encode(),
797+
DbResponse(db_id=1).encode(),
798+
]
799+
mock_reader.read.side_effect = responses
800+
801+
with patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)):
802+
await conn.connect()
803+
804+
# Mock execute to track calls and allow concurrent entry
805+
begin_entered = asyncio.Event()
806+
807+
async def mock_execute(sql: str, params=None):
808+
if sql == "BEGIN":
809+
begin_entered.set()
810+
await asyncio.sleep(0) # yield to let second coroutine enter
811+
return (0, 0)
812+
813+
conn.execute = mock_execute # type: ignore[assignment]
814+
815+
errors: list[Exception] = []
816+
817+
async def tx_a():
818+
async with conn.transaction():
819+
await asyncio.sleep(1)
820+
821+
async def tx_b():
822+
await begin_entered.wait()
823+
try:
824+
async with conn.transaction():
825+
pass
826+
except (OperationalError, InterfaceError) as e:
827+
errors.append(e)
828+
829+
task_a = asyncio.create_task(tx_a())
830+
task_b = asyncio.create_task(tx_b())
831+
832+
await task_b # should raise OperationalError about nested transactions
833+
834+
task_a.cancel()
835+
with pytest.raises(asyncio.CancelledError):
836+
await task_a
837+
838+
assert len(errors) == 1
839+
# Must get OperationalError about nested transactions, NOT InterfaceError
840+
assert isinstance(errors[0], OperationalError), (
841+
f"Expected OperationalError about nested transactions, "
842+
f"got {type(errors[0]).__name__}: {errors[0]}"
843+
)
844+
assert "Nested" in str(errors[0]) or "nested" in str(errors[0])

0 commit comments

Comments
 (0)