Skip to content

Commit 16df9d6

Browse files
fix: wake pool acquire() waiters promptly on close
pool.close() set _closed and drained idle connections but left tasks already blocked on the queue waiting up to 500 ms (the polling interval) to notice — effectively a silent close delay scaled by the number of waiters. Race the queue get against a _closed_event that close() now sets. Re-check _closed at the top of the acquire loop (and under the create branch's lock) so close takes effect immediately for in-flight waiters whether the event races them or they loop back around. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2603102 commit 16df9d6

2 files changed

Lines changed: 95 additions & 3 deletions

File tree

src/dqliteclient/pool.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
self._size = 0
8585
self._lock = asyncio.Lock()
8686
self._closed = False
87+
self._closed_event: asyncio.Event | None = None
8788
self._initialized = False
8889

8990
async def initialize(self) -> None:
@@ -102,6 +103,14 @@ async def _create_connection(self) -> DqliteConnection:
102103
self._size += 1
103104
return conn
104105

106+
def _get_closed_event(self) -> asyncio.Event:
107+
"""Lazily create the closed Event bound to the running loop."""
108+
if self._closed_event is None:
109+
self._closed_event = asyncio.Event()
110+
if self._closed:
111+
self._closed_event.set()
112+
return self._closed_event
113+
105114
async def _drain_idle(self) -> None:
106115
"""Close all idle connections in the pool.
107116
@@ -132,6 +141,9 @@ async def acquire(self) -> AsyncIterator[DqliteConnection]:
132141
conn: DqliteConnection | None = None
133142

134143
while conn is None:
144+
if self._closed:
145+
raise DqliteConnectionError("Pool is closed")
146+
135147
# Try to get an idle connection from the queue
136148
try:
137149
conn = self._pool.get_nowait()
@@ -141,6 +153,8 @@ async def acquire(self) -> AsyncIterator[DqliteConnection]:
141153

142154
# Try to create a new connection if under max
143155
async with self._lock:
156+
if self._closed:
157+
raise DqliteConnectionError("Pool is closed")
144158
if self._size < self._max_size:
145159
conn = await self._create_connection()
146160
break
@@ -153,10 +167,28 @@ async def acquire(self) -> AsyncIterator[DqliteConnection]:
153167
f"Timed out waiting for a connection from the pool "
154168
f"(max_size={self._max_size}, timeout={self._timeout}s)"
155169
)
170+
# Race the queue against the closed event so close() wakes
171+
# waiters promptly instead of leaving them on the polling loop.
172+
closed_event = self._get_closed_event()
173+
get_task: asyncio.Task[DqliteConnection] = asyncio.create_task(self._pool.get())
174+
closed_task = asyncio.create_task(closed_event.wait())
156175
try:
157-
conn = await asyncio.wait_for(self._pool.get(), timeout=min(remaining, 0.5))
158-
except TimeoutError:
159-
# Don't fail yet — loop back to re-check _size
176+
done, _pending = await asyncio.wait(
177+
{get_task, closed_task},
178+
timeout=min(remaining, 0.5),
179+
return_when=asyncio.FIRST_COMPLETED,
180+
)
181+
finally:
182+
if not closed_task.done():
183+
closed_task.cancel()
184+
if get_task in done:
185+
conn = get_task.result()
186+
else:
187+
# Either close fired or the poll timer fired; either way,
188+
# cancel the queue wait cleanly and let the loop re-check.
189+
get_task.cancel()
190+
with contextlib.suppress(BaseException):
191+
await get_task
160192
continue
161193

162194
# If connection is dead, discard and create a fresh one with leader discovery.
@@ -277,6 +309,8 @@ async def close(self) -> None:
277309
in-flight tasks before calling close().
278310
"""
279311
self._closed = True
312+
if self._closed_event is not None:
313+
self._closed_event.set()
280314
await self._drain_idle()
281315

282316
# In-use connections are closed by acquire()'s cleanup when they

tests/test_pool.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,64 @@ async def track_exec(*args, **kwargs):
736736
assert result is False
737737
assert not exec_called, "ROLLBACK must not be sent on a dead socket"
738738

739+
async def test_close_wakes_waiter_promptly(self) -> None:
740+
"""A task blocked in acquire() waiting for a connection must be
741+
woken quickly when pool.close() is called — not sit on the queue
742+
until its timeout expires.
743+
"""
744+
import asyncio
745+
import time
746+
747+
from dqliteclient.connection import DqliteConnection
748+
749+
pool = ConnectionPool(["localhost:9001"], max_size=1, timeout=5.0)
750+
751+
mock_conn = MagicMock(spec=DqliteConnection)
752+
mock_conn.is_connected = True
753+
mock_conn.close = AsyncMock()
754+
mock_conn._in_transaction = False
755+
mock_conn._in_use = False
756+
mock_conn._bound_loop = None
757+
mock_conn._pool_released = False
758+
mock_conn._check_in_use = MagicMock()
759+
760+
with patch.object(pool._cluster, "connect", return_value=mock_conn):
761+
await pool.initialize()
762+
763+
holder_acquired = asyncio.Event()
764+
release_holder = asyncio.Event()
765+
766+
async def hold_connection() -> None:
767+
async with pool.acquire():
768+
holder_acquired.set()
769+
await release_holder.wait()
770+
771+
holder = asyncio.create_task(hold_connection())
772+
await holder_acquired.wait()
773+
774+
async def try_acquire() -> BaseException | None:
775+
try:
776+
async with pool.acquire():
777+
return None
778+
except BaseException as e: # noqa: BLE001
779+
return e
780+
781+
waiter = asyncio.create_task(try_acquire())
782+
await asyncio.sleep(0.05)
783+
784+
t0 = time.monotonic()
785+
await pool.close()
786+
err = await asyncio.wait_for(waiter, timeout=1.0)
787+
elapsed = time.monotonic() - t0
788+
789+
release_holder.set()
790+
await holder
791+
792+
assert isinstance(err, DqliteConnectionError)
793+
assert elapsed < 0.3, (
794+
f"acquire() should wake within ~100ms of pool.close(); took {elapsed:.3f}s"
795+
)
796+
739797
async def test_reset_connection_returns_false_on_cancelled_error(self) -> None:
740798
"""_reset_connection must return False (not raise) when ROLLBACK is cancelled."""
741799
import asyncio

0 commit comments

Comments
 (0)