|
| 1 | +"""Pin: ``acquire()``'s capacity-wait timeout demux must not silently |
| 2 | +discard a connection that resolved ``get_task`` during the post-wait |
| 3 | +``await closed_task`` yield. |
| 4 | +
|
| 5 | +When ``asyncio.wait`` returns the timeout (``done == set()``) and the |
| 6 | +post-wait code cancels and awaits ``closed_task``, that await yields to |
| 7 | +the scheduler. A sibling ``_pool.put_nowait(conn)`` running at that |
| 8 | +yield resolves the still-pending ``get_task``. The original code's |
| 9 | +demux test ``if get_task in done`` uses the *snapshot* taken before the |
| 10 | +yield, so it incorrectly takes the else-arm: cancels (no-op on a done |
| 11 | +task), ``await get_task`` returns the connection, ``continue`` discards |
| 12 | +it. The reservation slot is never released because ``_release`` only |
| 13 | +fires for connections that flow back through the user's context |
| 14 | +manager. The pool permanently loses one slot of capacity per |
| 15 | +occurrence. |
| 16 | +
|
| 17 | +The fix replaces the snapshot-membership check with a live state check |
| 18 | +``get_task.done() and not get_task.cancelled() and get_task.exception() |
| 19 | +is None`` and routes a winning conn through the same put-back-or-release |
| 20 | +helper used by the existing ``except BaseException`` arm. |
| 21 | +""" |
| 22 | + |
| 23 | +from __future__ import annotations |
| 24 | + |
| 25 | +import asyncio |
| 26 | +from typing import Any |
| 27 | +from unittest.mock import MagicMock |
| 28 | + |
| 29 | +import pytest |
| 30 | + |
| 31 | +from dqliteclient.cluster import ClusterClient |
| 32 | +from dqliteclient.pool import ConnectionPool |
| 33 | + |
| 34 | + |
| 35 | +class _FakeConn: |
| 36 | + def __init__(self, name: str = "fake") -> None: |
| 37 | + self.name = name |
| 38 | + self._address = "localhost:9001" |
| 39 | + self._in_transaction = False |
| 40 | + self._tx_owner = None |
| 41 | + self._pool_released = False |
| 42 | + self._protocol = MagicMock() |
| 43 | + self._protocol._writer = MagicMock() |
| 44 | + self._protocol._writer.transport = MagicMock() |
| 45 | + self._protocol._writer.transport.is_closing = lambda: False |
| 46 | + self._protocol._reader = MagicMock() |
| 47 | + self._protocol._reader.at_eof = lambda: False |
| 48 | + self.close_called = False |
| 49 | + |
| 50 | + @property |
| 51 | + def is_connected(self) -> bool: |
| 52 | + return self._protocol is not None |
| 53 | + |
| 54 | + async def close(self) -> None: |
| 55 | + self.close_called = True |
| 56 | + self._protocol = None # type: ignore[assignment] |
| 57 | + |
| 58 | + |
def _make_pool() -> ConnectionPool:
    """Build a one-slot pool backed by a mocked cluster.

    The cluster's ``connect`` hands out fresh ``_FakeConn`` instances.
    ``max_size=1`` plus a short ``timeout`` make the capacity-wait
    branch trivial to drive from a test.
    """

    async def _connect(**_: Any) -> _FakeConn:
        return _FakeConn()

    fake_cluster = MagicMock(spec=ClusterClient)
    fake_cluster.connect = _connect
    return ConnectionPool(
        addresses=["localhost:9001"],
        min_size=0,
        max_size=1,
        timeout=0.1,
        cluster=fake_cluster,
    )
| 73 | + |
@pytest.mark.asyncio
async def test_acquire_timeout_race_does_not_discard_late_winning_get_task() -> None:
    """``asyncio.wait`` times out (``done == set()``) but a sibling
    ``put_nowait`` resolves ``get_task`` during the post-wait
    ``await closed_task`` yield. The conn must be put back on the
    queue (or routed to the user), not silently discarded by the
    stale-snapshot demux.

    Setup: simulate "pool already at max_size" by directly setting
    ``_size = 1`` (no real held connection). When ``acquire()`` enters
    the capacity-wait branch, our patched ``asyncio.wait`` drops a
    phantom connection into the queue synchronously before returning
    a timeout-shaped result (done=empty, both tasks still pending).
    The next loop iteration delivered to ``await closed_task`` then
    runs ``get_task.__step`` before our coroutine resumes, so
    ``get_task`` consumes the phantom and becomes done. The buggy
    post-wait demux at ``pool.py`` ``if get_task in done`` sees the
    stale empty snapshot, takes the else-arm, and discards the
    phantom on ``continue``. The fix re-checks ``get_task.done()``
    live and routes the conn through put-back-or-release.
    """
    pool = _make_pool()

    # Pretend max_size is reached so ``acquire()`` can't reserve and
    # enters the capacity-wait branch on its first iteration. Avoids
    # the asynccontextmanager-finalization noise from holding a real
    # acquire's slot across a bare ``__aenter__()``.
    pool._size = 1

    phantom = _FakeConn(name="phantom")
    # Bind the real put_nowait before any patching so fake_wait can
    # deposit the phantom on the genuine queue.
    original_put_nowait = pool._pool.put_nowait

    import dqliteclient.pool as pool_mod

    real_wait = asyncio.wait
    call_count = 0

    async def fake_wait(
        tasks: Any, *, timeout: Any = None, return_when: Any = None
    ) -> tuple[set[Any], set[Any]]:
        """First call: stage the race and fake a timeout; later calls
        delegate to the real ``asyncio.wait``."""
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            # Drop the phantom into the queue while ``get_task`` has
            # not yet had its first ``__step`` run (we are inside
            # ``await asyncio.wait`` synchronously — the loop hasn't
            # iterated since ``create_task``). When the post-wait code
            # subsequently yields on ``await closed_task``, the loop
            # runs ``get_task.__step`` first and ``get_task`` consumes
            # ``phantom``, becoming done with a real result, before
            # our coroutine resumes. The post-wait demux's
            # ``if get_task in done`` then sees the stale empty
            # snapshot and routes ``get_task`` into the
            # cancel-and-discard arm.
            original_put_nowait(phantom)  # type: ignore[arg-type]
            # Return timeout: done=empty, both tasks still pending
            # from the snapshot's perspective.
            return set(), set(tasks)
        # Subsequent calls: defer to the real wait so the deadline
        # actually consumes time and the loop terminates promptly.
        return await real_wait(tasks, timeout=timeout, return_when=return_when)

    # NOTE(review): ``pool_mod.asyncio`` is presumably the asyncio
    # module imported by pool.py, so this patches ``asyncio.wait``
    # process-wide for the duration of the try block — restored in the
    # finally below. Confirm pool.py does ``import asyncio`` (not a
    # local alias) before relying on this.
    pool_mod.asyncio.wait = fake_wait  # type: ignore[attr-defined]
    received: object | None = None
    try:
        # With the fix: the live-state recheck after ``await
        # closed_task`` finds get_task done with phantom and routes it
        # to the user (no timeout). Without the fix: stale snapshot
        # demux drops phantom on the floor; subsequent iterations time
        # out with an empty queue.
        async with pool.acquire() as conn:
            received = conn
    finally:
        pool_mod.asyncio.wait = real_wait  # type: ignore[attr-defined]

    # The phantom that ``put_nowait`` deposited during the
    # capacity-wait race must reach the user (or round-trip back to
    # the queue), never be silently discarded by the stale ``done``
    # snapshot demux.
    assert received is phantom, (
        f"acquire returned {received!r}, not the phantom that was put "
        "into the queue during the capacity-wait race — the post-wait "
        "demux's stale 'done' snapshot dropped phantom on the floor"
    )

    # Reset _size to the value used to simulate at-capacity so close()
    # does not hit the underflow guard. ``_release`` already
    # decremented _size by routing through ``_release_reservation`` on
    # __aexit__.
    pool._size = 0
    await pool.close()
| 166 | + |
@pytest.mark.asyncio
async def test_put_back_or_release_late_winner_queuefull_falls_back_to_close() -> None:
    """If the queue is full when the late-winner helper tries to put,
    it must close the conn and release the reservation rather than
    silently leak it.

    The QueueFull branch represents an "impossible" reservation-vs-
    capacity violation; the helper must handle it without dropping
    the conn on the floor or skipping the ``_size`` decrement that
    wakes sibling acquirers.
    """
    pool = _make_pool()

    # Fill the single slot of the bounded queue (max_size=1) so the
    # helper's put_nowait raises QueueFull immediately.
    pool._size = 1
    blocker = _FakeConn(name="occupant")
    pool._pool.put_nowait(blocker)  # type: ignore[arg-type]
    assert pool._pool.full()

    winner = _FakeConn(name="late_winner")
    await pool._put_back_or_release_late_winner(winner)  # type: ignore[arg-type]

    # put_nowait raised QueueFull, so the helper must have fallen back
    # to closing the conn instead of leaking it...
    assert winner.close_called is True

    # ...and must have released the reservation (size -= 1) so a
    # sibling acquirer can replace the slot.
    assert pool._size == 0

    # Cleanup: drain the occupying conn, then shut the pool down.
    assert pool._pool.get_nowait() is blocker
    await pool.close()
0 commit comments