Skip to content

Commit ec7b3eb

Browse files
fix: close writer directly in _query_leader() to prevent socket leak
The try/finally in _query_leader() closed the protocol object, but if DqliteProtocol construction failed (or any error occurred between open_connection and the try block), the writer/socket was leaked. Widen the try/finally to cover protocol construction and close the writer directly instead of through the protocol object.

Fixes #102

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6348b8e commit ec7b3eb

2 files changed

Lines changed: 50 additions & 4 deletions

File tree

src/dqliteclient/cluster.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,8 @@ async def _query_leader(self, address: str) -> str | None:
7575
except (TimeoutError, OSError):
7676
return None
7777

78-
protocol = DqliteProtocol(reader, writer, timeout=self._timeout)
79-
8078
try:
79+
protocol = DqliteProtocol(reader, writer, timeout=self._timeout)
8180
await protocol.handshake()
8281
node_id, leader_addr = await protocol.get_leader()
8382

@@ -91,8 +90,8 @@ async def _query_leader(self, address: str) -> str | None:
9190
# node_id=0 and empty address: no leader known
9291
return None
9392
finally:
94-
protocol.close()
95-
await protocol.wait_closed()
93+
writer.close()
94+
await writer.wait_closed()
9695

9796
async def connect(self, database: str = "default") -> DqliteConnection:
9897
"""Connect to the cluster leader.

tests/test_cluster.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,53 @@ async def hang_forever(*args, **kwargs):
137137
):
138138
await client.find_leader()
139139

140+
async def test_query_leader_closes_writer_on_handshake_error(self) -> None:
141+
"""Writer must be closed even if handshake raises an unexpected error."""
142+
store = MemoryNodeStore(["localhost:9001"])
143+
client = ClusterClient(store, timeout=1.0)
144+
145+
mock_reader = AsyncMock()
146+
mock_writer = MagicMock()
147+
mock_writer.drain = AsyncMock()
148+
mock_writer.close = MagicMock()
149+
mock_writer.wait_closed = AsyncMock()
150+
151+
# Handshake data that triggers a protocol error
152+
mock_reader.read.side_effect = [b"\x00" * 64]
153+
154+
with (
155+
patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)),
156+
pytest.raises(ClusterError),
157+
):
158+
await client.find_leader()
159+
160+
# The writer must have been closed to avoid socket leak
161+
mock_writer.close.assert_called()
162+
163+
async def test_query_leader_closes_writer_on_protocol_init_error(self) -> None:
164+
"""Writer must be closed even if DqliteProtocol construction fails."""
165+
store = MemoryNodeStore(["localhost:9001"])
166+
client = ClusterClient(store, timeout=1.0)
167+
168+
mock_reader = AsyncMock()
169+
mock_writer = MagicMock()
170+
mock_writer.drain = AsyncMock()
171+
mock_writer.close = MagicMock()
172+
mock_writer.wait_closed = AsyncMock()
173+
174+
with (
175+
patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)),
176+
patch(
177+
"dqliteclient.cluster.DqliteProtocol",
178+
side_effect=RuntimeError("init failed"),
179+
),
180+
pytest.raises(ClusterError),
181+
):
182+
await client.find_leader()
183+
184+
# Even though DqliteProtocol() failed, the writer must be closed
185+
mock_writer.close.assert_called()
186+
140187
async def test_update_nodes(self) -> None:
141188
store = MemoryNodeStore()
142189
client = ClusterClient(store)

0 commit comments

Comments
 (0)