Skip to content

Commit 51a55e9

Browse files
Forward trust_server_heartbeat into the leader-discovery probe
_query_leader constructed a fresh DqliteProtocol without threading the flag through, so operators who opted into a widened heartbeat window for the main query path got the default tight timeout during leader discovery. Thread trust_server_heartbeat through find_leader and _query_leader; connect() forwards the value at call time. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 52b649d commit 51a55e9

2 files changed

Lines changed: 100 additions & 17 deletions

File tree

src/dqliteclient/cluster.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,13 @@ def _check_redirect(self, address: str) -> None:
7878
if not self._redirect_policy(address):
7979
raise ClusterError(f"Leader redirect to {address!r} rejected by redirect_policy")
8080

81-
async def find_leader(self) -> str:
81+
async def find_leader(self, *, trust_server_heartbeat: bool = False) -> str:
8282
"""Find the current cluster leader.
8383
84-
Returns the leader address.
84+
Returns the leader address. ``trust_server_heartbeat`` is forwarded
85+
to each probe protocol so operators who opted into a widened
86+
heartbeat window for the main query path get the same semantics
87+
during leader discovery.
8588
"""
8689
nodes = await self._node_store.get_nodes()
8790

@@ -109,7 +112,11 @@ async def find_leader(self) -> str:
109112
for node in nodes:
110113
try:
111114
leader_address = await asyncio.wait_for(
112-
self._query_leader(node.address), timeout=self._timeout
115+
self._query_leader(
116+
node.address,
117+
trust_server_heartbeat=trust_server_heartbeat,
118+
),
119+
timeout=self._timeout,
113120
)
114121
if leader_address:
115122
# Only leader_address values that did NOT come from
@@ -133,7 +140,9 @@ async def find_leader(self) -> str:
133140

134141
raise ClusterError(f"Could not find leader. Errors: {'; '.join(errors)}") from last_exc
135142

136-
async def _query_leader(self, address: str) -> str | None:
143+
async def _query_leader(
144+
self, address: str, *, trust_server_heartbeat: bool = False
145+
) -> str | None:
137146
"""Query a node for the current leader."""
138147
host, port = _parse_address(address)
139148

@@ -146,7 +155,12 @@ async def _query_leader(self, address: str) -> str | None:
146155
return None
147156

148157
try:
149-
protocol = DqliteProtocol(reader, writer, timeout=self._timeout)
158+
protocol = DqliteProtocol(
159+
reader,
160+
writer,
161+
timeout=self._timeout,
162+
trust_server_heartbeat=trust_server_heartbeat,
163+
)
150164
await protocol.handshake()
151165
node_id, leader_addr = await protocol.get_leader()
152166

@@ -198,7 +212,9 @@ async def try_connect() -> DqliteConnection:
198212
attempt = attempt_counter[0]
199213
leader: str | None = None
200214
try:
201-
leader = await self.find_leader()
215+
leader = await self.find_leader(
216+
trust_server_heartbeat=trust_server_heartbeat,
217+
)
202218
conn = DqliteConnection(
203219
leader,
204220
database=database,

tests/test_cluster.py

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ async def test_find_leader_success(self) -> None:
5252

5353
from dqlitewire.messages import LeaderResponse, WelcomeResponse
5454

55-
# First call for handshake, second for leader query
55+
# Upstream raft_leader sets id and address atomically: a voter
56+
# that IS the leader returns its own id AND its own address
57+
# (never (nonzero, "")).
5658
responses = [
5759
WelcomeResponse(heartbeat_timeout=15000).encode(),
58-
LeaderResponse(node_id=1, address="").encode(), # Empty = this node is leader
60+
LeaderResponse(node_id=1, address="localhost:9001").encode(),
5961
]
6062
mock_reader.read.side_effect = responses
6163

@@ -217,7 +219,7 @@ async def test_find_leader_skips_node_with_bad_handshake(self) -> None:
217219
responses = [
218220
b"\x00" * 64,
219221
WelcomeResponse(heartbeat_timeout=15000).encode(),
220-
LeaderResponse(node_id=2, address="").encode(),
222+
LeaderResponse(node_id=2, address="localhost:9002").encode(),
221223
]
222224
mock_reader.read.side_effect = responses
223225

@@ -251,7 +253,7 @@ async def hang_forever():
251253

252254
responses = [
253255
WelcomeResponse(heartbeat_timeout=15000).encode(),
254-
LeaderResponse(node_id=1, address="").encode(),
256+
LeaderResponse(node_id=1, address="localhost:9001").encode(),
255257
]
256258
mock_reader.read.side_effect = responses
257259

@@ -276,7 +278,7 @@ async def test_find_leader_propagates_programming_bugs(self) -> None:
276278
store = MemoryNodeStore(["localhost:9001", "localhost:9002"])
277279
client = ClusterClient(store, timeout=0.5)
278280

279-
async def buggy_query(_address: str) -> str | None:
281+
async def buggy_query(_address: str, **_kwargs: object) -> str | None:
280282
raise TypeError("programmer mistake")
281283

282284
with (
@@ -297,7 +299,7 @@ async def test_find_leader_transport_error_chains_cause(self) -> None:
297299

298300
boom = DqliteConnectionError("handshake failed")
299301

300-
async def failing_query(_address: str) -> str | None:
302+
async def failing_query(_address: str, **_kwargs: object) -> str | None:
301303
raise boom
302304

303305
with (
@@ -318,7 +320,7 @@ async def test_find_leader_randomizes_node_order(self) -> None:
318320

319321
first_probed: list[str] = []
320322

321-
async def track(address: str) -> str | None:
323+
async def track(address: str, **_kwargs: object) -> str | None:
322324
first_probed.append(address)
323325
raise DqliteConnectionError("not leader")
324326

@@ -350,7 +352,7 @@ async def test_find_leader_probes_voters_before_non_voters(self) -> None:
350352

351353
order: list[str] = []
352354

353-
async def track(address: str) -> str | None:
355+
async def track(address: str, **_kwargs: object) -> str | None:
354356
order.append(address)
355357
return None # no leader known — keep probing
356358

@@ -383,7 +385,7 @@ async def test_connect_does_not_retry_plain_sql_errors(self) -> None:
383385

384386
call_count = 0
385387

386-
async def always_sql_error() -> str:
388+
async def always_sql_error(**_kwargs: object) -> str:
387389
nonlocal call_count
388390
call_count += 1
389391
raise OperationalError(1, "some sql error")
@@ -431,7 +433,7 @@ async def test_max_attempts_override_honored(self) -> None:
431433

432434
call_count = [0]
433435

434-
async def fake_find_leader() -> str:
436+
async def fake_find_leader(**_kwargs: object) -> str:
435437
call_count[0] += 1
436438
raise DqliteConnectionError("unreachable")
437439

@@ -457,7 +459,7 @@ async def test_failed_attempts_logged(self, caplog: pytest.LogCaptureFixture) ->
457459
store = MemoryNodeStore(["localhost:1"]) # unreachable
458460
client = ClusterClient(store, timeout=0.1)
459461

460-
async def fake_find_leader() -> str:
462+
async def fake_find_leader(**_kwargs: object) -> str:
461463
raise DqliteConnectionError("simulated")
462464

463465
client.find_leader = fake_find_leader # type: ignore[method-assign]
@@ -472,3 +474,68 @@ async def fake_find_leader() -> str:
472474
f"Expected 2 per-attempt log lines, got {len(attempt_logs)}: "
473475
f"{[r.message for r in attempt_logs]}"
474476
)
477+
478+
479+
class TestQueryLeaderTrustsHeartbeat:
480+
"""_query_leader forwards the trust_server_heartbeat flag."""
481+
482+
async def test_flag_propagates_to_probe_protocol(self) -> None:
483+
store = MemoryNodeStore(["localhost:9001"])
484+
client = ClusterClient(store, timeout=1.0)
485+
486+
mock_reader = AsyncMock()
487+
mock_writer = MagicMock()
488+
mock_writer.drain = AsyncMock()
489+
mock_writer.close = MagicMock()
490+
mock_writer.wait_closed = AsyncMock()
491+
492+
captured: dict[str, object] = {}
493+
494+
class FakeProto:
495+
def __init__(self, *args: object, **kwargs: object) -> None:
496+
captured.update(kwargs)
497+
498+
async def handshake(self) -> None:
499+
pass
500+
501+
async def get_leader(self) -> tuple[int, str]:
502+
return (1, "localhost:9001")
503+
504+
with (
505+
patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)),
506+
patch("dqliteclient.cluster.DqliteProtocol", FakeProto),
507+
):
508+
await client._query_leader("localhost:9001", trust_server_heartbeat=True)
509+
510+
assert captured.get("trust_server_heartbeat") is True
511+
512+
async def test_flag_default_false(self) -> None:
513+
store = MemoryNodeStore(["localhost:9001"])
514+
client = ClusterClient(store, timeout=1.0)
515+
516+
mock_reader = AsyncMock()
517+
mock_writer = MagicMock()
518+
mock_writer.drain = AsyncMock()
519+
mock_writer.close = MagicMock()
520+
mock_writer.wait_closed = AsyncMock()
521+
522+
captured: dict[str, object] = {}
523+
524+
class FakeProto:
525+
def __init__(self, *args: object, **kwargs: object) -> None:
526+
captured.update(kwargs)
527+
528+
async def handshake(self) -> None:
529+
pass
530+
531+
async def get_leader(self) -> tuple[int, str]:
532+
return (1, "localhost:9001")
533+
534+
with (
535+
patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)),
536+
patch("dqliteclient.cluster.DqliteProtocol", FakeProto),
537+
):
538+
await client._query_leader("localhost:9001")
539+
540+
assert captured.get("trust_server_heartbeat") is False
541+

0 commit comments

Comments
 (0)