Skip to content

Commit e261652

Browse files
Narrow _abort_protocol to suppress only expected drain errors
``contextlib.suppress(BaseException)`` caught CancelledError along with the TimeoutError from the bounded wait_closed, violating the structured concurrency contract: an outer ``asyncio.timeout`` scope that cancels mid-abort would observe the original connect failure instead of its own cancellation. Narrow to (TimeoutError, OSError) for the expected slow-drain / already-closed paths and DEBUG-log any other Exception; CancelledError and other non-Exception BaseException subclasses now propagate. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ffa5bb9 commit e261652

2 files changed

Lines changed: 69 additions & 1 deletion

File tree

src/dqliteclient/connection.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import contextlib
5+
import logging
56
import math
67
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
78
from contextlib import asynccontextmanager
@@ -22,6 +23,8 @@
2223
from dqlitewire import LEADER_ERROR_CODES as _LEADER_ERROR_CODES
2324
from dqlitewire.exceptions import EncodeError as _WireEncodeError
2425

26+
logger = logging.getLogger(__name__)
27+
2528

2629
def _parse_address(address: str) -> tuple[str, int]:
2730
"""Parse a host:port address string, handling IPv6 brackets."""
@@ -219,8 +222,19 @@ async def _abort_protocol(self) -> None:
219222
return
220223
self._protocol = None
221224
protocol.close()
222-
with contextlib.suppress(BaseException):
225+
# Narrow the suppression: a bounded wait on the transport drain
226+
# can legitimately raise TimeoutError (slow peer) or OSError
227+
# (already-closed writer). Non-``Exception`` errors — especially
228+
# CancelledError from an outer ``asyncio.timeout`` scope — must
229+
# propagate so structured-concurrency cancellation semantics
230+
# remain intact. Any other Exception is DEBUG-logged for
231+
# diagnostics and then absorbed rather than re-raised.
232+
try:
223233
await asyncio.wait_for(protocol.wait_closed(), timeout=0.5)
234+
except (TimeoutError, OSError):
235+
pass
236+
except Exception: # pragma: no cover
237+
logger.debug("_abort_protocol: unexpected drain error", exc_info=True)
224238

225239
async def __aenter__(self) -> "DqliteConnection":
226240
await self.connect()

tests/test_connection.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,3 +981,57 @@ async def task_b():
981981
)
982982
assert isinstance(errors[0], InterfaceError)
983983
assert "transaction" in str(errors[0]).lower()
984+
985+
986+
class TestAbortProtocolNarrowSuppression:
987+
"""_abort_protocol must let CancelledError propagate so an outer
988+
``asyncio.timeout`` scope sees the cancellation rather than having it
989+
swallowed along with the TimeoutError from the bounded drain.
990+
"""
991+
992+
async def test_outer_cancel_propagates_through_abort(self) -> None:
993+
import asyncio
994+
from unittest.mock import MagicMock
995+
996+
import pytest
997+
998+
from dqliteclient.connection import DqliteConnection
999+
1000+
conn = DqliteConnection("localhost:9001", database="x", timeout=1.0)
1001+
proto = MagicMock()
1002+
proto.close = MagicMock()
1003+
1004+
# wait_closed hangs forever; we wrap the abort in wait_for with
1005+
# a short outer timeout and assert the outer TimeoutError
1006+
# propagates (in the previous BaseException-suppressing code
1007+
# it would have been eaten).
1008+
async def hang_forever() -> None:
1009+
await asyncio.sleep(999)
1010+
1011+
proto.wait_closed = hang_forever
1012+
conn._protocol = proto
1013+
1014+
with pytest.raises(TimeoutError):
1015+
await asyncio.wait_for(conn._abort_protocol(), timeout=0.1)
1016+
1017+
async def test_timeout_during_drain_is_suppressed(self) -> None:
1018+
"""The bounded wait_closed budget *internally* expiring is
1019+
expected (slow peer) and must not propagate."""
1020+
import asyncio
1021+
from unittest.mock import MagicMock
1022+
1023+
from dqliteclient.connection import DqliteConnection
1024+
1025+
conn = DqliteConnection("localhost:9001", database="x", timeout=1.0)
1026+
proto = MagicMock()
1027+
proto.close = MagicMock()
1028+
1029+
async def hang_forever() -> None:
1030+
await asyncio.sleep(999)
1031+
1032+
proto.wait_closed = hang_forever
1033+
conn._protocol = proto
1034+
1035+
# No outer timeout: the internal 0.5s wait_for expires, the
1036+
# resulting TimeoutError is suppressed, and the call returns.
1037+
await conn._abort_protocol()

0 commit comments

Comments
 (0)