Skip to content

Commit b87f3e0

Browse files
fix: catch BaseException in transaction() to ROLLBACK on CancelledError
Since Python 3.9, CancelledError inherits from BaseException, not Exception. The `except Exception:` handler in transaction() did not catch CancelledError, so ROLLBACK was never sent when a task was cancelled inside a transaction block. The finally block still reset _in_transaction to False, making the connection appear clean while the server-side transaction remained open. Change `except Exception:` to `except BaseException:` and update the inner `contextlib.suppress(Exception)` to `contextlib.suppress(BaseException)` to also handle cancellation during ROLLBACK itself. Closes #078 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b1f58e0 commit b87f3e0

2 files changed

Lines changed: 52 additions & 2 deletions

File tree

src/dqliteclient/connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,9 @@ async def transaction(self) -> AsyncIterator[None]:
217217
try:
218218
yield
219219
await self.execute("COMMIT")
220-
except Exception:
220+
except BaseException:
221221
# Swallow rollback failure; original exception is more important
222-
with contextlib.suppress(Exception):
222+
with contextlib.suppress(BaseException):
223223
await self.execute("ROLLBACK")
224224
raise
225225
finally:

tests/test_connection.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,56 @@ async def test_fetchval_returns_none_for_empty(self) -> None:
443443
result = await conn.fetchval("SELECT id FROM t WHERE 1=0")
444444
assert result is None
445445

446+
async def test_transaction_rollback_on_cancellation(self) -> None:
447+
"""CancelledError inside a transaction must trigger ROLLBACK."""
448+
import asyncio
449+
450+
conn = DqliteConnection("localhost:9001")
451+
452+
mock_reader = AsyncMock()
453+
mock_writer = MagicMock()
454+
mock_writer.drain = AsyncMock()
455+
mock_writer.close = MagicMock()
456+
mock_writer.wait_closed = AsyncMock()
457+
458+
from dqlitewire.messages import DbResponse, WelcomeResponse
459+
460+
responses = [
461+
WelcomeResponse(heartbeat_timeout=15000).encode(),
462+
DbResponse(db_id=1).encode(),
463+
]
464+
mock_reader.read.side_effect = responses
465+
466+
with patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)):
467+
await conn.connect()
468+
469+
# Track which SQL statements are executed
470+
call_log: list[str] = []
471+
472+
async def mock_execute(sql: str, params=None):
473+
call_log.append(sql)
474+
return (0, 0)
475+
476+
conn.execute = mock_execute # type: ignore[assignment]
477+
478+
async def cancelled_transaction():
479+
async with conn.transaction():
480+
await asyncio.sleep(10) # Will be cancelled here
481+
482+
task = asyncio.create_task(cancelled_transaction())
483+
await asyncio.sleep(0) # Let the task enter the transaction
484+
task.cancel()
485+
486+
with pytest.raises(asyncio.CancelledError):
487+
await task
488+
489+
# ROLLBACK must have been issued
490+
assert "ROLLBACK" in call_log, (
491+
f"ROLLBACK was not issued on CancelledError. Calls: {call_log}"
492+
)
493+
# _in_transaction must be cleaned up
494+
assert not conn._in_transaction
495+
446496
async def test_not_leader_error_invalidates_connection(self) -> None:
447497
"""OperationalError with 'not leader' code should invalidate the connection."""
448498
conn = DqliteConnection("localhost:9001")

0 commit comments

Comments
 (0)