|
| 1 | +"""Pin: cancellation landing during the pre-ping path's shielded |
| 2 | +``conn.close()`` does NOT leak a reservation slot. |
| 3 | +
|
| 4 | +The pre-ping branch in ``ConnectionPool.acquire`` shields the dead- |
| 5 | +connection close with ``asyncio.shield`` so the close completes |
| 6 | +even if the outer caller is cancelled. The drain-idle and create |
| 7 | +calls each have their own ``except BaseException`` arms that release |
| 8 | +the reservation. This test pins the cumulative invariant: under |
| 9 | +rapid cancellation during the pre-ping path, ``pool._size`` does |
| 10 | +not drift above ``pool._max_size``, and a follow-up ``acquire`` |
| 11 | +still works. |
| 12 | +
|
| 13 | +The current implementation is correct (verified by tracing each |
| 14 | +exception path); this test guards against a refactor that re-orders |
| 15 | +the close/drain/create sequence and silently breaks the |
| 16 | +no-leak invariant. |
| 17 | +""" |
| 18 | + |
| 19 | +from __future__ import annotations |
| 20 | + |
| 21 | +import asyncio |
| 22 | +import contextlib |
| 23 | +from unittest.mock import AsyncMock, MagicMock |
| 24 | + |
| 25 | +import pytest |
| 26 | + |
| 27 | +from dqliteclient.pool import ConnectionPool |
| 28 | + |
| 29 | + |
| 30 | +def _alive_conn(*, dead: bool = False, slow_close: bool = False) -> MagicMock: |
| 31 | + """Build a MagicMock that mimics the slice of ``DqliteConnection`` |
| 32 | + the pool's acquire flow touches. ``slow_close=True`` makes |
| 33 | + ``close()`` await once so a cancel can land during the shielded |
| 34 | + body.""" |
| 35 | + conn = MagicMock() |
| 36 | + conn.is_connected = True |
| 37 | + if slow_close: |
| 38 | + |
| 39 | + async def _slow_close() -> None: |
| 40 | + await asyncio.sleep(0) |
| 41 | + |
| 42 | + conn.close = AsyncMock(side_effect=_slow_close) |
| 43 | + else: |
| 44 | + conn.close = AsyncMock() |
| 45 | + conn._pool_released = False |
| 46 | + proto = MagicMock() |
| 47 | + proto.is_wire_coherent = True |
| 48 | + transport = MagicMock() |
| 49 | + transport.is_closing.return_value = dead |
| 50 | + proto._writer = MagicMock() |
| 51 | + proto._writer.transport = transport |
| 52 | + proto._reader = MagicMock() |
| 53 | + proto._reader.at_eof.return_value = False |
| 54 | + conn._protocol = proto |
| 55 | + return conn |
| 56 | + |
| 57 | + |
| 58 | +@pytest.mark.asyncio |
| 59 | +async def test_pre_ping_cancel_during_close_no_slot_leak( |
| 60 | + monkeypatch: pytest.MonkeyPatch, |
| 61 | +) -> None: |
| 62 | + """Cancel landing during the shielded ``conn.close()`` of the |
| 63 | + dead-conn path must release the reservation. After the cancel, |
| 64 | + pool size must be back to a state where new acquires succeed |
| 65 | + without exceeding max_size.""" |
| 66 | + pool = ConnectionPool(["a:9001"], min_size=0, max_size=1) |
| 67 | + monkeypatch.setattr("dqliteclient.pool.ConnectionPool.initialize", AsyncMock()) |
| 68 | + pool._initialized = True |
| 69 | + dead = _alive_conn(dead=True, slow_close=True) |
| 70 | + pool._pool.put_nowait(dead) |
| 71 | + pool._size = 1 |
| 72 | + |
| 73 | + fresh = _alive_conn(dead=False) |
| 74 | + |
| 75 | + async def fake_create(self: object) -> MagicMock: |
| 76 | + return fresh |
| 77 | + |
| 78 | + monkeypatch.setattr("dqliteclient.pool.ConnectionPool._create_connection", fake_create) |
| 79 | + |
| 80 | + async def caller() -> None: |
| 81 | + async with pool.acquire(): |
| 82 | + pass |
| 83 | + |
| 84 | + task = asyncio.create_task(caller()) |
| 85 | + # Yield enough times for the task to enter the pre-ping shield. |
| 86 | + for _ in range(5): |
| 87 | + await asyncio.sleep(0) |
| 88 | + task.cancel() |
| 89 | + with contextlib.suppress(asyncio.CancelledError): |
| 90 | + await task |
| 91 | + |
| 92 | + # Reservation must not leak. |
| 93 | + assert pool._size <= pool._max_size, ( |
| 94 | + f"reservation leaked under cancel: _size={pool._size} > max_size={pool._max_size}" |
| 95 | + ) |
| 96 | + |
| 97 | + # And the pool must still be acquireable. |
| 98 | + async with pool.acquire() as conn: |
| 99 | + assert conn is not None |
| 100 | + |
| 101 | + |
| 102 | +@pytest.mark.asyncio |
| 103 | +async def test_pre_ping_cancel_after_close_before_drain_no_leak( |
| 104 | + monkeypatch: pytest.MonkeyPatch, |
| 105 | +) -> None: |
| 106 | + """Cancel landing during ``_drain_idle`` (after the shielded close |
| 107 | + completes) must also release the reservation.""" |
| 108 | + pool = ConnectionPool(["a:9001"], min_size=0, max_size=1) |
| 109 | + monkeypatch.setattr("dqliteclient.pool.ConnectionPool.initialize", AsyncMock()) |
| 110 | + pool._initialized = True |
| 111 | + dead = _alive_conn(dead=True) # close completes synchronously |
| 112 | + pool._pool.put_nowait(dead) |
| 113 | + pool._size = 1 |
| 114 | + |
| 115 | + # _drain_idle is the first await after close; cancel during its |
| 116 | + # await releases via the except BaseException arm. |
| 117 | + async def slow_drain(self: object) -> None: |
| 118 | + await asyncio.sleep(0) |
| 119 | + |
| 120 | + monkeypatch.setattr("dqliteclient.pool.ConnectionPool._drain_idle", slow_drain) |
| 121 | + |
| 122 | + fresh = _alive_conn(dead=False) |
| 123 | + |
| 124 | + async def fake_create(self: object) -> MagicMock: |
| 125 | + return fresh |
| 126 | + |
| 127 | + monkeypatch.setattr("dqliteclient.pool.ConnectionPool._create_connection", fake_create) |
| 128 | + |
| 129 | + async def caller() -> None: |
| 130 | + async with pool.acquire(): |
| 131 | + pass |
| 132 | + |
| 133 | + task = asyncio.create_task(caller()) |
| 134 | + for _ in range(5): |
| 135 | + await asyncio.sleep(0) |
| 136 | + task.cancel() |
| 137 | + with contextlib.suppress(asyncio.CancelledError): |
| 138 | + await task |
| 139 | + |
| 140 | + assert pool._size <= pool._max_size |
| 141 | + async with pool.acquire() as conn: |
| 142 | + assert conn is not None |
0 commit comments