Skip to content

Commit a5fc30c

Browse files
Chain find_leader per-node failures via BaseExceptionGroup when more than one node fails
``ClusterClient._find_leader_impl`` overwrote ``last_exc`` on every iteration; the final ``raise ClusterError(...) from last_exc`` chained only the LAST iteration's exception even though the message listed ALL N nodes' failures. Code branching on ``e.__cause__`` type (e.g. routing security alerts for ``ProtocolError``-from-malformed-redirect) saw a non- deterministic decision based on iteration ordering — depending on the random shuffle, a security-relevant ``ProtocolError`` on node 1 was overwritten by a benign ``TimeoutError`` on node 3. Collect every per-node exception into a list and chain via ``BaseExceptionGroup`` when more than one node contributed. Single-exception case keeps the narrow chain so existing callers continue to work. No-exception case (every node returned no-leader-known) raises with no chain — the message itself is the diagnostic. Mirrors the discipline already applied to ``ConnectionPool.initialize`` partial-failure handling. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 5a52405 commit a5fc30c

2 files changed

Lines changed: 106 additions & 4 deletions

File tree

src/dqliteclient/cluster.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,15 @@ async def _find_leader_impl(self, *, trust_server_heartbeat: bool) -> str:
327327
nodes.sort(key=lambda n: 0 if n.role == NodeRole.VOTER else 1)
328328

329329
errors: list[str] = []
330-
last_exc: BaseException | None = None
330+
# Collect every per-node BaseException so the final
331+
# ``ClusterError`` can chain them all via
332+
# ``BaseExceptionGroup``. Previously only the LAST iteration's
333+
# exception was preserved on ``__cause__`` — code that
334+
# branches on the cause class (e.g. routing security alerts
335+
# for ``ProtocolError``-from-malformed-redirect) saw a non-
336+
# deterministic decision based on iteration ordering. Mirrors
337+
# the discipline already applied to ``ConnectionPool.initialize``.
338+
per_node_excs: list[BaseException] = []
331339
total_nodes = len(nodes)
332340

333341
for idx, node in enumerate(nodes):
@@ -386,7 +394,7 @@ async def _find_leader_impl(self, *, trust_server_heartbeat: bool) -> str:
386394
total_nodes,
387395
)
388396
errors.append(f"{_safe_addr}: timed out")
389-
last_exc = e
397+
per_node_excs.append(e)
390398
continue
391399
except (DqliteConnectionError, ProtocolError, OperationalError, OSError) as e:
392400
# Narrow the catch so programming bugs (TypeError, KeyError,
@@ -402,7 +410,7 @@ async def _find_leader_impl(self, *, trust_server_heartbeat: bool) -> str:
402410
total_nodes,
403411
)
404412
errors.append(f"{_sanitize_display_text(node.address)}: {_truncate_error(str(e))}")
405-
last_exc = e
413+
per_node_excs.append(e)
406414
continue
407415

408416
joined = "; ".join(errors)
@@ -411,7 +419,21 @@ async def _find_leader_impl(self, *, trust_server_heartbeat: bool) -> str:
411419
joined = (
412420
joined[:_MAX_AGGREGATE_ERROR_PAYLOAD] + f"... [aggregate truncated, {kept} chars]"
413421
)
414-
raise ClusterError(f"Could not find leader. Errors: {joined}") from last_exc
422+
# Chain via ``BaseExceptionGroup`` when more than one node
423+
# contributed a real exception (the no-leader-known arm
424+
# produces no exception, only an entry in ``errors``). Single-
425+
# exception case keeps the narrow chain so existing callers
426+
# that branch on ``e.__cause__`` type continue to work.
427+
# No-exception case (every node returned no-leader-known)
428+
# raises with no chain — the message itself is the
429+
# diagnostic. Mirrors ``ConnectionPool.initialize``'s discipline.
430+
if len(per_node_excs) > 1:
431+
raise ClusterError(f"Could not find leader. Errors: {joined}") from BaseExceptionGroup(
432+
"find_leader: per-node failures", per_node_excs
433+
)
434+
if per_node_excs:
435+
raise ClusterError(f"Could not find leader. Errors: {joined}") from per_node_excs[0]
436+
raise ClusterError(f"Could not find leader. Errors: {joined}")
415437

416438
async def _query_leader(
417439
self, address: str, *, trust_server_heartbeat: bool = False
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Pin: ``ClusterClient._find_leader_impl`` chains per-node
2+
failures via ``BaseExceptionGroup`` when more than one node
3+
contributed a real exception.
4+
5+
Pre-fix, only the LAST iteration's exception was preserved on
6+
``__cause__`` — code that branches on the cause class (e.g.
7+
routing security alerts for ``ProtocolError``-from-malformed-
8+
redirect) saw a non-deterministic decision based on iteration
9+
ordering.
10+
11+
Mirrors ``ConnectionPool.initialize``'s
12+
``BaseExceptionGroup``-on-multi-failure discipline.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
from unittest.mock import AsyncMock
18+
19+
import pytest
20+
21+
from dqliteclient.cluster import ClusterClient
22+
from dqliteclient.exceptions import (
23+
ClusterError,
24+
DqliteConnectionError,
25+
ProtocolError,
26+
)
27+
from dqliteclient.node_store import MemoryNodeStore
28+
29+
30+
@pytest.mark.asyncio
31+
async def test_find_leader_aggregate_chains_via_exception_group_when_multiple_failures() -> None:
32+
store = MemoryNodeStore(["node-a:9001", "node-b:9001", "node-c:9001"])
33+
cluster = ClusterClient(store, timeout=0.5)
34+
35+
# node-a → ProtocolError (security-relevant)
36+
# node-b → DqliteConnectionError (transport)
37+
# node-c → TimeoutError
38+
async def _query_leader_per_node(address: str, **_kw: object) -> None:
39+
if address == "node-a:9001":
40+
raise ProtocolError("malformed redirect from node-a")
41+
if address == "node-b:9001":
42+
raise DqliteConnectionError("connection refused on node-b")
43+
raise TimeoutError("node-c timed out")
44+
45+
cluster._query_leader = AsyncMock(side_effect=_query_leader_per_node)
46+
47+
with pytest.raises(ClusterError) as exc_info:
48+
await cluster.find_leader()
49+
50+
cause = exc_info.value.__cause__
51+
# Multi-failure case: cause must be BaseExceptionGroup so callers
52+
# branching on .split(ProtocolError) recover the security-relevant
53+
# exception regardless of iteration order.
54+
assert isinstance(cause, BaseExceptionGroup), (
55+
f"expected BaseExceptionGroup chain on multi-failure, got {type(cause).__name__}"
56+
)
57+
matched, _ = cause.split(ProtocolError)
58+
assert matched is not None and len(matched.exceptions) == 1
59+
matched_transport, _ = cause.split(DqliteConnectionError)
60+
assert matched_transport is not None and len(matched_transport.exceptions) == 1
61+
62+
63+
@pytest.mark.asyncio
64+
async def test_find_leader_aggregate_keeps_narrow_chain_on_single_failure() -> None:
65+
"""Backward-compat: single-failure case keeps the narrow chain so
66+
callers that branch on ``e.__cause__`` type continue to work."""
67+
store = MemoryNodeStore(["node-a:9001"])
68+
cluster = ClusterClient(store, timeout=0.5)
69+
70+
async def _query_leader_fails(address: str, **_kw: object) -> None:
71+
raise ProtocolError("malformed redirect")
72+
73+
cluster._query_leader = AsyncMock(side_effect=_query_leader_fails)
74+
75+
with pytest.raises(ClusterError) as exc_info:
76+
await cluster.find_leader()
77+
78+
cause = exc_info.value.__cause__
79+
assert isinstance(cause, ProtocolError)
80+
assert "malformed redirect" in str(cause)

0 commit comments

Comments
 (0)