Skip to content

Commit 1edc2b5

Browse files
fix: protect connect() with _in_use guard to prevent TOCTOU race
connect() was not protected by the _in_use re-entrance guard added in #077. Two coroutines calling connect() on the same DqliteConnection could both pass the `if self._protocol is not None: return` check, both open TCP connections, and the second would overwrite self._protocol, permanently leaking the first connection's socket. Add _check_in_use() and _in_use flag to connect(), consistent with the existing pattern in execute/fetch/fetchall/fetchval. This turns a silent socket leak into a loud InterfaceError. Closes #091 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 91521cd commit 1edc2b5

File tree

2 files changed

+80
-20
lines changed

2 files changed

+80
-20
lines changed

src/dqliteclient/connection.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,30 +87,39 @@ def is_connected(self) -> bool:
8787

8888
async def connect(self) -> None:
8989
"""Establish connection to the database."""
90+
self._check_in_use()
9091
if self._protocol is not None:
9192
return
9293

93-
host, port = _parse_address(self._address)
94-
95-
try:
96-
reader, writer = await asyncio.wait_for(
97-
asyncio.open_connection(host, port),
98-
timeout=self._timeout,
99-
)
100-
except TimeoutError as e:
101-
raise DqliteConnectionError(f"Connection to {self._address} timed out") from e
102-
except OSError as e:
103-
raise DqliteConnectionError(f"Failed to connect to {self._address}: {e}") from e
104-
105-
self._protocol = DqliteProtocol(reader, writer, timeout=self._timeout)
106-
94+
self._in_use = True
10795
try:
108-
await self._protocol.handshake()
109-
self._db_id = await self._protocol.open_database(self._database)
110-
except BaseException:
111-
self._protocol.close()
112-
self._protocol = None
113-
raise
96+
host, port = _parse_address(self._address)
97+
98+
try:
99+
reader, writer = await asyncio.wait_for(
100+
asyncio.open_connection(host, port),
101+
timeout=self._timeout,
102+
)
103+
except TimeoutError as e:
104+
raise DqliteConnectionError(
105+
f"Connection to {self._address} timed out"
106+
) from e
107+
except OSError as e:
108+
raise DqliteConnectionError(
109+
f"Failed to connect to {self._address}: {e}"
110+
) from e
111+
112+
self._protocol = DqliteProtocol(reader, writer, timeout=self._timeout)
113+
114+
try:
115+
await self._protocol.handshake()
116+
self._db_id = await self._protocol.open_database(self._database)
117+
except BaseException:
118+
self._protocol.close()
119+
self._protocol = None
120+
raise
121+
finally:
122+
self._in_use = False
114123

115124
async def close(self) -> None:
116125
"""Close the connection."""

tests/test_connection.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,3 +675,54 @@ async def second_execute():
675675

676676
assert len(errors) == 1
677677
assert "another operation is in progress" in str(errors[0])
678+
679+
async def test_concurrent_connect_raises_interface_error(self) -> None:
680+
"""Two coroutines connecting the same object must raise InterfaceError."""
681+
import asyncio
682+
683+
from dqliteclient.exceptions import InterfaceError
684+
685+
conn = DqliteConnection("localhost:9001")
686+
687+
gate = asyncio.Event()
688+
689+
async def slow_open(*args, **kwargs):
690+
await gate.wait()
691+
reader = AsyncMock()
692+
writer = MagicMock()
693+
writer.drain = AsyncMock()
694+
writer.close = MagicMock()
695+
writer.wait_closed = AsyncMock()
696+
return reader, writer
697+
698+
errors: list[Exception] = []
699+
700+
async def first_connect():
701+
with (
702+
patch("asyncio.open_connection", side_effect=slow_open),
703+
patch("dqliteclient.connection.DqliteProtocol") as MockProto,
704+
):
705+
proto = MagicMock()
706+
proto.handshake = AsyncMock()
707+
proto.open_database = AsyncMock(return_value=1)
708+
proto.close = MagicMock()
709+
MockProto.return_value = proto
710+
await conn.connect()
711+
712+
async def second_connect():
713+
await asyncio.sleep(0) # Let first_connect start
714+
try:
715+
await conn.connect()
716+
except InterfaceError as e:
717+
errors.append(e)
718+
719+
task1 = asyncio.create_task(first_connect())
720+
task2 = asyncio.create_task(second_connect())
721+
722+
await asyncio.sleep(0) # Let both start
723+
gate.set()
724+
725+
await asyncio.gather(task1, task2, return_exceptions=True)
726+
727+
assert len(errors) == 1
728+
assert "another operation is in progress" in str(errors[0])

0 commit comments

Comments
 (0)