Skip to content

Commit 8abd130

Browse files
Fix race condition in async connection initialization
Add asyncio.Lock to AsyncConnection._ensure_connection() to prevent concurrent callers from getting a half-initialized connection. Set _async_conn only after connect() succeeds. Add threading.Lock to Connection._ensure_loop() for the sync wrapper. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f9b71ba commit 8abd130

3 files changed

Lines changed: 72 additions & 11 deletions

File tree

src/dqlitedbapi/aio/connection.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Async connection implementation for dqlite."""
22

3+
import asyncio
34
from typing import Any
45

56
from dqliteclient import DqliteConnection
@@ -29,24 +30,33 @@ def __init__(
2930
self._timeout = timeout
3031
self._async_conn: DqliteConnection | None = None
3132
self._closed = False
33+
self._connect_lock = asyncio.Lock()
3234

3335
async def _ensure_connection(self) -> DqliteConnection:
3436
"""Ensure the underlying connection is established."""
3537
if self._closed:
3638
raise InterfaceError("Connection is closed")
3739

38-
if self._async_conn is None:
39-
self._async_conn = DqliteConnection(
40+
if self._async_conn is not None:
41+
return self._async_conn
42+
43+
async with self._connect_lock:
44+
# Double-check after acquiring lock
45+
if self._async_conn is not None:
46+
return self._async_conn
47+
48+
conn = DqliteConnection(
4049
self._address,
4150
database=self._database,
4251
timeout=self._timeout,
4352
)
4453
try:
45-
await self._async_conn.connect()
54+
await conn.connect()
4655
except Exception as e:
47-
self._async_conn = None
4856
raise OperationalError(f"Failed to connect: {e}") from e
4957

58+
self._async_conn = conn
59+
5060
return self._async_conn
5161

5262
async def connect(self) -> None:

src/dqlitedbapi/connection.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,21 @@ def __init__(
3333
self._closed = False
3434
self._loop: asyncio.AbstractEventLoop | None = None
3535
self._thread: threading.Thread | None = None
36+
self._loop_lock = threading.Lock()
3637

3738
def _ensure_loop(self) -> asyncio.AbstractEventLoop:
3839
"""Ensure a dedicated event loop is running in a background thread.
3940
4041
This allows sync methods to work even when called from within
4142
an already-running async context (e.g. uvicorn).
4243
"""
43-
if self._loop is None or self._loop.is_closed():
44-
self._loop = asyncio.new_event_loop()
45-
self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
46-
self._thread.start()
44+
if self._loop is not None and not self._loop.is_closed():
45+
return self._loop
46+
with self._loop_lock:
47+
if self._loop is None or self._loop.is_closed():
48+
self._loop = asyncio.new_event_loop()
49+
self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
50+
self._thread.start()
4751
return self._loop
4852

4953
def _run_sync(self, coro: Any) -> Any:
@@ -62,16 +66,16 @@ async def _get_async_connection(self) -> DqliteConnection:
6266
raise InterfaceError("Connection is closed")
6367

6468
if self._async_conn is None:
65-
self._async_conn = DqliteConnection(
69+
conn = DqliteConnection(
6670
self._address,
6771
database=self._database,
6872
timeout=self._timeout,
6973
)
7074
try:
71-
await self._async_conn.connect()
75+
await conn.connect()
7276
except Exception as e:
73-
self._async_conn = None
7477
raise OperationalError(f"Failed to connect: {e}") from e
78+
self._async_conn = conn
7579

7680
return self._async_conn
7781

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Tests for race condition in async connection initialization."""
2+
3+
import asyncio
4+
from unittest.mock import AsyncMock, patch
5+
6+
import pytest
7+
8+
from dqlitedbapi.aio.connection import AsyncConnection
9+
10+
11+
class TestAsyncConnectionRace:
12+
@pytest.mark.asyncio
13+
async def test_concurrent_ensure_connection_waits_for_connect(self) -> None:
14+
"""A second _ensure_connection() call must not return before connect() finishes."""
15+
conn = AsyncConnection("localhost:9001")
16+
17+
connect_started = asyncio.Event()
18+
connect_finished = asyncio.Event()
19+
20+
async def slow_connect() -> None:
21+
connect_started.set()
22+
await asyncio.sleep(0.1) # Simulate slow TCP handshake
23+
connect_finished.set()
24+
25+
with patch("dqlitedbapi.aio.connection.DqliteConnection") as MockDqliteConn:
26+
mock_instance = AsyncMock()
27+
mock_instance.connect = slow_connect
28+
mock_instance._protocol = AsyncMock()
29+
mock_instance._db_id = 0
30+
MockDqliteConn.return_value = mock_instance
31+
32+
async def second_caller() -> bool:
33+
# Wait for first caller to start connecting
34+
await connect_started.wait()
35+
# Now call _ensure_connection — it should wait for connect to finish
36+
await conn._ensure_connection()
37+
# At this point, connect must have finished
38+
return connect_finished.is_set()
39+
40+
first_task = asyncio.create_task(conn._ensure_connection())
41+
second_task = asyncio.create_task(second_caller())
42+
43+
await first_task
44+
connect_was_finished = await second_task
45+
46+
# The second caller must have seen connect_finished=True
47+
assert connect_was_finished, "Second caller got connection before connect() finished"

0 commit comments

Comments
 (0)