Commit 1722813

Use live get_task state in pool acquire post-wait demux
The capacity-wait branch's post-wait demux used ``if get_task in done:``, but ``done`` is the snapshot taken before the post-wait ``await closed_task`` yield. A sibling ``_release`` ``put_nowait`` during that yield resolves ``get_task``; the stale snapshot still says False, so the else-arm cancels (a no-op on a done task) and ``await get_task`` returns the conn, which ``continue`` then silently discards. The reservation slot leaks because ``_release`` only fires for connections that flow back through the user's context manager.

Fix: replace the snapshot membership test with a live-state predicate (``get_task.done() and not get_task.cancelled() and get_task.exception() is None``) and add a second live-state recheck after the cancel-and-await in the else-arm. Factor the existing put_nowait-or-close-and-release fallback (already present in the ``except BaseException`` arm) into a private helper, ``_put_back_or_release_late_winner``, so both call sites share one code path.
1 parent 27150c8 commit 1722813
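
The race generalizes beyond the pool: the ``done`` set returned by ``asyncio.wait`` is a snapshot frozen at return time, while ``Task.done()`` reflects live state after any later yield. A minimal, self-contained sketch of the hazard (illustrative only, not code from this commit; the queue stands in for ``_pool`` and the bare ``sleep(0)`` for the post-wait ``await closed_task``):

import asyncio


async def main() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue()
    get_task = asyncio.create_task(queue.get())

    # Timeout fires first: the snapshot reports get_task as pending.
    done, pending = await asyncio.wait({get_task}, timeout=0.01)
    assert get_task not in done

    # A "sibling" deposits an item, then we yield once. The scheduler
    # runs get_task before resuming us, so it consumes the item.
    queue.put_nowait("conn")
    await asyncio.sleep(0)

    print(get_task in done)  # False: the snapshot is stale.
    print(                   # True: the live-state predicate sees the win.
        get_task.done()
        and not get_task.cancelled()
        and get_task.exception() is None
    )


asyncio.run(main())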

2 files changed: 260 additions & 25 deletions

File tree

src/dqliteclient/pool.py

Lines changed: 59 additions & 25 deletions
@@ -441,6 +441,48 @@ async def _create_connection(self) -> DqliteConnection:
             max_attempts=self._max_attempts,
         )
 
+    async def _put_back_or_release_late_winner(self, conn: DqliteConnection) -> None:
+        """Put a connection back on the queue, or close + release the
+        reservation if the queue is full.
+
+        Used in two places in ``acquire()``:
+
+        1. The ``except BaseException`` arm — outer cancel raced with
+           a successful ``get_task`` (a sibling ``_release``
+           ``put_nowait`` ran between our snapshot and the cancel).
+        2. The post-wait demux's else-arm — timeout snapshot raced a
+           winning ``get_task`` during the post-wait
+           ``await closed_task``.
+
+        Without this routing, the conn is referenced only by the
+        soon-to-be-GC'd ``get_task`` and silently disappears. Its
+        reservation slot is never released because ``_release`` only
+        fires for connections that flow back through the user's
+        context manager — so the pool permanently loses one slot of
+        capacity per occurrence.
+        """
+        try:
+            self._pool.put_nowait(conn)
+        except asyncio.QueueFull:
+            # Invariant violation: reservations should track queue
+            # capacity exactly, so a full queue on return is
+            # "impossible." If it happens anyway, silently dropping
+            # the reference would leak a live reader task and a
+            # socket. Close explicitly and adjust the reservation
+            # count so the pool shrinks cleanly instead of leaking.
+            # Suppression of close's own errors is narrow — OSError on
+            # an already-dead writer is expected; anything else
+            # propagates.
+            with contextlib.suppress(OSError):
+                await conn.close()
+            # Route through ``_release_reservation`` so the counter
+            # stays lock-protected and sibling acquirers parked on
+            # ``closed_event.wait()`` get woken via
+            # ``_signal_state_change``. Shield so a nested cancel
+            # cannot leave ``_size`` inconsistent.
+            with contextlib.suppress(asyncio.CancelledError):
+                await asyncio.shield(self._release_reservation())
+
     async def _release_reservation(self) -> None:
         """Decrement ``_size`` under the lock, waking waiters.
 
@@ -704,30 +746,7 @@ async def acquire(self) -> AsyncIterator[DqliteConnection]:
                     # valid; return it to the queue so the next
                     # acquirer can use it instead of closing and
                     # releasing (which would shrink _size).
-                    conn_won = get_task.result()
-                    try:
-                        self._pool.put_nowait(conn_won)
-                    except asyncio.QueueFull:
-                        # Invariant violation: reservations should track
-                        # queue capacity exactly, so a full queue on
-                        # return is "impossible." If it happens anyway,
-                        # silently dropping the reference would leak a
-                        # live reader task and a socket. Close
-                        # explicitly and adjust the reservation count
-                        # so the pool shrinks cleanly instead of
-                        # leaking. Suppression of close's own errors is
-                        # narrow — OSError on an already-dead writer is
-                        # expected; anything else propagates.
-                        with contextlib.suppress(OSError):
-                            await conn_won.close()
-                        # Route through the helper so the counter
-                        # stays lock-protected and sibling acquirers
-                        # parked on ``closed_event.wait()`` get
-                        # woken via ``_signal_state_change``. Shield
-                        # so a nested cancel cannot leave ``_size``
-                        # inconsistent.
-                        with contextlib.suppress(asyncio.CancelledError):
-                            await asyncio.shield(self._release_reservation())
+                    await self._put_back_or_release_late_winner(get_task.result())
                 elif get_task is not None and not get_task.done():
                     get_task.cancel()
                     with contextlib.suppress(asyncio.CancelledError):
@@ -755,14 +774,29 @@ async def acquire(self) -> AsyncIterator[DqliteConnection]:
                 # must still propagate.
                 with contextlib.suppress(asyncio.CancelledError):
                     await closed_task
-                if get_task in done:
+                if get_task.done() and not get_task.cancelled() and get_task.exception() is None:
+                    # Live-state check: ``done`` is the snapshot taken
+                    # before the post-wait ``await closed_task`` yield
+                    # above, during which a sibling ``_release`` can
+                    # ``put_nowait`` and resolve ``get_task``. Trusting
+                    # ``get_task in done`` would silently route the
+                    # winning conn into the cancel-and-discard arm,
+                    # leaking one slot of capacity per occurrence.
                     conn = get_task.result()
                 else:
                     # Either close fired or the poll timer fired; either way,
                     # cancel the queue wait cleanly and let the loop re-check.
                     get_task.cancel()
                     with contextlib.suppress(asyncio.CancelledError):
                         await get_task
+                    if get_task.done() and not get_task.cancelled() and get_task.exception() is None:
+                        # Cancel raced a successful get during the await
+                        # above (a sibling ``_release`` put_nowait between
+                        # our cancel call and the cancel actually
+                        # delivering). Route the conn back via the same
+                        # put-back-or-release path the outer-cancel arm
+                        # uses, so the slot is not leaked.
+                        await self._put_back_or_release_late_winner(get_task.result())
                     continue
 
             # If connection is dead, discard and create a fresh one with leader discovery.
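
The ordering of the three checks in the live-state predicate is load-bearing: ``Task.exception()`` raises ``InvalidStateError`` on a pending task and ``CancelledError`` on a cancelled one, so ``done()`` and ``not cancelled()`` must short-circuit first. A hypothetical standalone form of the predicate (the commit inlines the three checks at both demux sites rather than naming a function like this):

import asyncio


def _task_won(task: "asyncio.Task[object]") -> bool:
    """True only if ``task`` finished with a usable result right now.

    ``done()`` must be checked first (``exception()`` raises
    InvalidStateError on a pending task), ``cancelled()`` second
    (``exception()`` raises CancelledError on a cancelled one). Only
    then is ``exception() is None`` safe, and only then is
    ``task.result()`` guaranteed not to raise.
    """
    return task.done() and not task.cancelled() and task.exception() is None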
Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
+"""Pin: ``acquire()``'s capacity-wait timeout demux must not silently
+discard a connection that resolved ``get_task`` during the post-wait
+``await closed_task`` yield.
+
+When ``asyncio.wait`` returns the timeout (``done == set()``) and the
+post-wait code cancels and awaits ``closed_task``, that await yields to
+the scheduler. A sibling ``_pool.put_nowait(conn)`` running at that
+yield resolves the still-pending ``get_task``. The original code's
+demux test ``if get_task in done`` uses the *snapshot* taken before the
+yield, so it incorrectly takes the else-arm: cancels (a no-op on a done
+task), ``await get_task`` returns the connection, ``continue`` discards
+it. The reservation slot is never released because ``_release`` only
+fires for connections that flow back through the user's context
+manager. The pool permanently loses one slot of capacity per
+occurrence.
+
+The fix replaces the snapshot-membership check with a live state check
+``get_task.done() and not get_task.cancelled() and get_task.exception()
+is None`` and routes a winning conn through the same put-back-or-release
+helper used by the existing ``except BaseException`` arm.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+
+from dqliteclient.cluster import ClusterClient
+from dqliteclient.pool import ConnectionPool
+
+
+class _FakeConn:
+    def __init__(self, name: str = "fake") -> None:
+        self.name = name
+        self._address = "localhost:9001"
+        self._in_transaction = False
+        self._tx_owner = None
+        self._pool_released = False
+        self._protocol = MagicMock()
+        self._protocol._writer = MagicMock()
+        self._protocol._writer.transport = MagicMock()
+        self._protocol._writer.transport.is_closing = lambda: False
+        self._protocol._reader = MagicMock()
+        self._protocol._reader.at_eof = lambda: False
+        self.close_called = False
+
+    @property
+    def is_connected(self) -> bool:
+        return self._protocol is not None
+
+    async def close(self) -> None:
+        self.close_called = True
+        self._protocol = None  # type: ignore[assignment]
+
+
+def _make_pool() -> ConnectionPool:
+    async def _connect(**_: Any) -> _FakeConn:
+        return _FakeConn()
+
+    cluster = MagicMock(spec=ClusterClient)
+    cluster.connect = _connect
+    return ConnectionPool(
+        addresses=["localhost:9001"],
+        min_size=0,
+        max_size=1,
+        timeout=0.1,
+        cluster=cluster,
+    )
+
+
+@pytest.mark.asyncio
+async def test_acquire_timeout_race_does_not_discard_late_winning_get_task() -> None:
+    """``asyncio.wait`` times out (``done == set()``) but a sibling
+    ``put_nowait`` resolves ``get_task`` during the post-wait
+    ``await closed_task`` yield. The conn must be put back on the
+    queue (or routed to the user), not silently discarded by the
+    stale-snapshot demux.
+
+    Setup: simulate "pool already at max_size" by directly setting
+    ``_size = 1`` (no real held connection). When ``acquire()`` enters
+    the capacity-wait branch, our patched ``asyncio.wait`` drops a
+    phantom connection into the queue synchronously before returning
+    a timeout-shaped result (done=empty, both tasks still pending).
+    The next loop iteration delivered to ``await closed_task`` then
+    runs ``get_task.__step`` before our coroutine resumes, so
+    ``get_task`` consumes the phantom and becomes done. The buggy
+    post-wait demux at ``pool.py`` ``if get_task in done`` sees the
+    stale empty snapshot, takes the else-arm, and discards the
+    phantom on ``continue``. The fix re-checks ``get_task.done()``
+    live and routes the conn through put-back-or-release.
+    """
+    pool = _make_pool()
+
+    # Pretend max_size is reached so ``acquire()`` can't reserve and
+    # enters the capacity-wait branch on its first iteration. Avoids
+    # the asynccontextmanager-finalization noise from holding a real
+    # acquire's slot across a bare ``__aenter__()``.
+    pool._size = 1
+
+    phantom = _FakeConn(name="phantom")
+    original_put_nowait = pool._pool.put_nowait
+
+    import dqliteclient.pool as pool_mod
+
+    real_wait = asyncio.wait
+    call_count = 0
+
+    async def fake_wait(
+        tasks: Any, *, timeout: Any = None, return_when: Any = None
+    ) -> tuple[set[Any], set[Any]]:
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            # Drop the phantom into the queue while ``get_task`` has
+            # not yet had its first ``__step`` run (we are inside
+            # ``await asyncio.wait`` synchronously — the loop hasn't
+            # iterated since ``create_task``). When the post-wait code
+            # subsequently yields on ``await closed_task``, the loop
+            # runs ``get_task.__step`` first and ``get_task`` consumes
+            # ``phantom``, becoming done with a real result, before
+            # our coroutine resumes. The post-wait demux's
+            # ``if get_task in done`` then sees the stale empty
+            # snapshot and routes ``get_task`` into the
+            # cancel-and-discard arm.
+            original_put_nowait(phantom)  # type: ignore[arg-type]
+            # Return timeout: done=empty, both tasks still pending
+            # from the snapshot's perspective.
+            return set(), set(tasks)
+        # Subsequent calls: defer to the real wait so the deadline
+        # actually consumes time and the loop terminates promptly.
+        return await real_wait(tasks, timeout=timeout, return_when=return_when)
+
+    pool_mod.asyncio.wait = fake_wait  # type: ignore[attr-defined]
+    received: object | None = None
+    try:
+        # With the fix: the live-state recheck after ``await
+        # closed_task`` finds get_task done with phantom and routes it
+        # to the user (no timeout). Without the fix: stale snapshot
+        # demux drops phantom on the floor; subsequent iterations time
+        # out with an empty queue.
+        async with pool.acquire() as conn:
+            received = conn
+    finally:
+        pool_mod.asyncio.wait = real_wait  # type: ignore[attr-defined]
+
+    # The phantom that ``put_nowait`` deposited during the
+    # capacity-wait race must reach the user (or round-trip back to
+    # the queue), never be silently discarded by the stale ``done``
+    # snapshot demux.
+    assert received is phantom, (
+        f"acquire returned {received!r}, not the phantom that was put "
+        "into the queue during the capacity-wait race — the post-wait "
+        "demux's stale 'done' snapshot dropped phantom on the floor"
+    )
+
+    # Reset the artificially-set ``_size`` so close() does not hit the
+    # underflow guard. ``_release`` already decremented _size by
+    # routing through ``_release_reservation`` on __aexit__.
+    pool._size = 0
+    await pool.close()
+
+
+@pytest.mark.asyncio
+async def test_put_back_or_release_late_winner_queuefull_falls_back_to_close() -> None:
+    """If the queue is full when the late-winner helper tries to put,
+    it must close the conn and release the reservation rather than
+    silently leak it.
+
+    The QueueFull branch represents an "impossible" reservation-vs-
+    capacity violation; the helper must handle it without dropping
+    the conn on the floor or skipping the ``_size`` decrement that
+    wakes sibling acquirers.
+    """
+    pool = _make_pool()
+
+    # Pre-fill the bounded queue (max_size=1) so the helper's
+    # put_nowait immediately raises QueueFull.
+    pool._size = 1
+    occupant = _FakeConn(name="occupant")
+    pool._pool.put_nowait(occupant)  # type: ignore[arg-type]
+    assert pool._pool.full()
+
+    late_winner = _FakeConn(name="late_winner")
+    await pool._put_back_or_release_late_winner(late_winner)  # type: ignore[arg-type]
+
+    # The late_winner must have been close()'d (because put_nowait
+    # raised QueueFull, the helper falls back to close + release).
+    assert late_winner.close_called is True
+
+    # The reservation must have been released (size -= 1) so a
+    # sibling acquirer can replace the slot.
+    assert pool._size == 0
+
+    # Cleanup: drain the occupant from the queue.
+    queued = pool._pool.get_nowait()
+    assert queued is occupant
+    await pool.close()