Skip to content

Commit 877c0a8

Browse files
Recover the queue slot when pool acquire is cancelled mid-wait
ConnectionPool.acquire() parks on asyncio.wait({get_task, closed_task}) when the pool is saturated. asyncio.wait does not cancel its argument tasks on outer cancellation, so when the parent coroutine is cancelled while suspended in that wait, `get_task` kept running. It could then win a subsequent queue.put() race, orphan the connection inside its own result, and silently shrink the pool's effective capacity — the connection was never reachable again, but _size still counted it. Restructure the wait block to run a cancellation-aware cleanup. On the cancel path, stop closed_task, and either: - if get_task already took a connection out of the queue, return it via put_nowait so the reservation backing it stays valid and the next acquirer can pick it up, or - if get_task is still pending, cancel and drain it. Re-raise to let cancellation propagate. The normal (non-cancel) path is unchanged: cancel closed_task, take get_task's result if it won, otherwise cancel the queue wait and re-check state. Add two regression tests: a single cancel/release cycle asserting that a subsequent acquire succeeds, and a 100-round fuzz loop asserting _size stays consistent. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 048c070 commit 877c0a8

File tree

2 files changed

+137
-1
lines changed

2 files changed

+137
-1
lines changed

src/dqliteclient/pool.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,31 @@ async def acquire(self) -> AsyncIterator[DqliteConnection]:
344344
timeout=remaining,
345345
return_when=asyncio.FIRST_COMPLETED,
346346
)
347-
finally:
347+
except BaseException:
348+
# Outer cancellation of this coroutine while suspended in
349+
# ``asyncio.wait``. ``asyncio.wait`` does not cancel its
350+
# argument tasks, so both children are still alive — we
351+
# MUST stop them before propagating, otherwise the
352+
# abandoned get_task can win a later queue.put() and
353+
# orphan a connection (silently shrinking pool capacity).
348354
if not closed_task.done():
349355
closed_task.cancel()
356+
if get_task.done() and not get_task.cancelled() and get_task.exception() is None:
357+
# Outer cancel raced with a successful get. The
358+
# reservation that backed this connection is still
359+
# valid; return it to the queue so the next
360+
# acquirer can use it instead of closing and
361+
# releasing (which would shrink _size).
362+
conn_won = get_task.result()
363+
with contextlib.suppress(asyncio.QueueFull):
364+
self._pool.put_nowait(conn_won)
365+
elif not get_task.done():
366+
get_task.cancel()
367+
with contextlib.suppress(BaseException):
368+
await get_task
369+
raise
370+
if not closed_task.done():
371+
closed_task.cancel()
350372
if get_task in done:
351373
conn = get_task.result()
352374
else:

tests/test_pool.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,3 +1303,117 @@ async def test_ordinary_exception_is_absorbed_and_logged(
13031303
c2.close.assert_awaited_once()
13041304
assert pool._size == 0
13051305
assert any("boom" in rec.getMessage() for rec in caplog.records)
1306+
1307+
1308+
class TestAcquireCancellationPreservesCapacity:
1309+
"""An external cancellation of a coroutine parked in ``acquire()``'s
1310+
``asyncio.wait(...)`` must not leak a connection or shrink effective
1311+
pool capacity. The abandoned ``get_task`` could win a subsequent
1312+
``put()`` race; the fix must either cancel it or reclaim the
1313+
connection it took.
1314+
"""
1315+
1316+
@pytest.fixture
1317+
def mock_connection(self) -> MagicMock:
1318+
conn = MagicMock()
1319+
conn.is_connected = True
1320+
conn._in_transaction = False
1321+
conn._pool_released = False
1322+
conn.connect = AsyncMock()
1323+
conn.close = AsyncMock()
1324+
conn._protocol = MagicMock()
1325+
conn._protocol._writer = MagicMock()
1326+
conn._protocol._writer.transport = MagicMock()
1327+
conn._protocol._writer.transport.is_closing = MagicMock(return_value=False)
1328+
conn._protocol._reader = MagicMock()
1329+
conn._protocol._reader.at_eof = MagicMock(return_value=False)
1330+
return conn
1331+
1332+
async def test_cancelled_waiter_does_not_wedge_pool(
1333+
self,
1334+
mock_connection: MagicMock,
1335+
) -> None:
1336+
pool = ConnectionPool(["localhost:9001"], max_size=1, timeout=5.0)
1337+
1338+
with patch.object(pool._cluster, "connect", return_value=mock_connection):
1339+
await pool.initialize()
1340+
1341+
holder_released = asyncio.Event()
1342+
waiter_parked = asyncio.Event()
1343+
1344+
async def holder() -> None:
1345+
async with pool.acquire():
1346+
waiter_parked.set()
1347+
# Hold until the waiter has been cancelled AND we've
1348+
# been signalled to exit.
1349+
await holder_released.wait()
1350+
1351+
async def waiter() -> None:
1352+
async with pool.acquire():
1353+
pass
1354+
1355+
holder_task = asyncio.create_task(holder())
1356+
await waiter_parked.wait()
1357+
1358+
waiter_task = asyncio.create_task(waiter())
1359+
# Give waiter() time to enter acquire() and park in asyncio.wait.
1360+
await asyncio.sleep(0.05)
1361+
1362+
waiter_task.cancel()
1363+
with contextlib.suppress(asyncio.CancelledError):
1364+
await waiter_task
1365+
1366+
# Release the in-use connection; the post-cancel put() must
1367+
# not be stolen by the abandoned get_task.
1368+
holder_released.set()
1369+
await holder_task
1370+
1371+
# Pool must still accept a new acquire within a short window —
1372+
# if the abandoned task captured the connection, this would
1373+
# time out.
1374+
async with asyncio.timeout(1.0):
1375+
async with pool.acquire() as conn:
1376+
assert conn is mock_connection
1377+
1378+
assert pool._size == 1
1379+
1380+
await pool.close()
1381+
1382+
async def test_repeated_cancel_release_keeps_size_consistent(
1383+
self,
1384+
mock_connection: MagicMock,
1385+
) -> None:
1386+
"""Invariant: after any cancel/release cycle the ``_size`` counter
1387+
equals the number of connections actually reachable. Fuzz-test 100
1388+
rounds (smaller than the issue file's 1000 — kept light for CI)."""
1389+
pool = ConnectionPool(["localhost:9001"], max_size=1, timeout=5.0)
1390+
1391+
with patch.object(pool._cluster, "connect", return_value=mock_connection):
1392+
await pool.initialize()
1393+
1394+
for _ in range(100):
1395+
release_event = asyncio.Event()
1396+
1397+
async def holder(event: asyncio.Event = release_event) -> None:
1398+
async with pool.acquire():
1399+
await event.wait()
1400+
1401+
async def waiter() -> None:
1402+
async with pool.acquire():
1403+
pass
1404+
1405+
h = asyncio.create_task(holder())
1406+
await asyncio.sleep(0)
1407+
w = asyncio.create_task(waiter())
1408+
await asyncio.sleep(0.01)
1409+
w.cancel()
1410+
with contextlib.suppress(asyncio.CancelledError):
1411+
await w
1412+
release_event.set()
1413+
await h
1414+
1415+
# After every round the pool must have exactly one
1416+
# reachable reservation — queued or returned-by-holder.
1417+
assert pool._size == 1
1418+
1419+
await pool.close()

0 commit comments

Comments
 (0)