@@ -52,10 +52,12 @@ async def test_find_leader_success(self) -> None:
5252
5353 from dqlitewire .messages import LeaderResponse , WelcomeResponse
5454
55- # First call for handshake, second for leader query
55+ # Upstream raft_leader sets id and address atomically: a voter
56+ # that IS the leader returns its own id AND its own address
57+ # (never (nonzero, "")).
5658 responses = [
5759 WelcomeResponse (heartbeat_timeout = 15000 ).encode (),
58- LeaderResponse (node_id = 1 , address = "" ).encode (), # Empty = this node is leader
60+ LeaderResponse (node_id = 1 , address = "localhost:9001 " ).encode (),
5961 ]
6062 mock_reader .read .side_effect = responses
6163
@@ -217,7 +219,7 @@ async def test_find_leader_skips_node_with_bad_handshake(self) -> None:
217219 responses = [
218220 b"\x00 " * 64 ,
219221 WelcomeResponse (heartbeat_timeout = 15000 ).encode (),
220- LeaderResponse (node_id = 2 , address = "" ).encode (),
222+ LeaderResponse (node_id = 2 , address = "localhost:9002 " ).encode (),
221223 ]
222224 mock_reader .read .side_effect = responses
223225
@@ -251,7 +253,7 @@ async def hang_forever():
251253
252254 responses = [
253255 WelcomeResponse (heartbeat_timeout = 15000 ).encode (),
254- LeaderResponse (node_id = 1 , address = "" ).encode (),
256+ LeaderResponse (node_id = 1 , address = "localhost:9001 " ).encode (),
255257 ]
256258 mock_reader .read .side_effect = responses
257259
@@ -276,7 +278,7 @@ async def test_find_leader_propagates_programming_bugs(self) -> None:
276278 store = MemoryNodeStore (["localhost:9001" , "localhost:9002" ])
277279 client = ClusterClient (store , timeout = 0.5 )
278280
279- async def buggy_query (_address : str ) -> str | None :
281+ async def buggy_query (_address : str , ** _kwargs : object ) -> str | None :
280282 raise TypeError ("programmer mistake" )
281283
282284 with (
@@ -297,7 +299,7 @@ async def test_find_leader_transport_error_chains_cause(self) -> None:
297299
298300 boom = DqliteConnectionError ("handshake failed" )
299301
300- async def failing_query (_address : str ) -> str | None :
302+ async def failing_query (_address : str , ** _kwargs : object ) -> str | None :
301303 raise boom
302304
303305 with (
@@ -318,7 +320,7 @@ async def test_find_leader_randomizes_node_order(self) -> None:
318320
319321 first_probed : list [str ] = []
320322
321- async def track (address : str ) -> str | None :
323+ async def track (address : str , ** _kwargs : object ) -> str | None :
322324 first_probed .append (address )
323325 raise DqliteConnectionError ("not leader" )
324326
@@ -350,7 +352,7 @@ async def test_find_leader_probes_voters_before_non_voters(self) -> None:
350352
351353 order : list [str ] = []
352354
353- async def track (address : str ) -> str | None :
355+ async def track (address : str , ** _kwargs : object ) -> str | None :
354356 order .append (address )
355357 return None # no leader known — keep probing
356358
@@ -383,7 +385,7 @@ async def test_connect_does_not_retry_plain_sql_errors(self) -> None:
383385
384386 call_count = 0
385387
386- async def always_sql_error () -> str :
388+ async def always_sql_error (** _kwargs : object ) -> str :
387389 nonlocal call_count
388390 call_count += 1
389391 raise OperationalError (1 , "some sql error" )
@@ -431,7 +433,7 @@ async def test_max_attempts_override_honored(self) -> None:
431433
432434 call_count = [0 ]
433435
434- async def fake_find_leader () -> str :
436+ async def fake_find_leader (** _kwargs : object ) -> str :
435437 call_count [0 ] += 1
436438 raise DqliteConnectionError ("unreachable" )
437439
@@ -457,7 +459,7 @@ async def test_failed_attempts_logged(self, caplog: pytest.LogCaptureFixture) ->
457459 store = MemoryNodeStore (["localhost:1" ]) # unreachable
458460 client = ClusterClient (store , timeout = 0.1 )
459461
460- async def fake_find_leader () -> str :
462+ async def fake_find_leader (** _kwargs : object ) -> str :
461463 raise DqliteConnectionError ("simulated" )
462464
463465 client .find_leader = fake_find_leader # type: ignore[method-assign]
@@ -472,3 +474,68 @@ async def fake_find_leader() -> str:
472474 f"Expected 2 per-attempt log lines, got { len (attempt_logs )} : "
473475 f"{ [r .message for r in attempt_logs ]} "
474476 )
477+
478+
479+ class TestQueryLeaderTrustsHeartbeat :
480+ """_query_leader forwards the trust_server_heartbeat flag."""
481+
482+ async def test_flag_propagates_to_probe_protocol (self ) -> None :
483+ store = MemoryNodeStore (["localhost:9001" ])
484+ client = ClusterClient (store , timeout = 1.0 )
485+
486+ mock_reader = AsyncMock ()
487+ mock_writer = MagicMock ()
488+ mock_writer .drain = AsyncMock ()
489+ mock_writer .close = MagicMock ()
490+ mock_writer .wait_closed = AsyncMock ()
491+
492+ captured : dict [str , object ] = {}
493+
494+ class FakeProto :
495+ def __init__ (self , * args : object , ** kwargs : object ) -> None :
496+ captured .update (kwargs )
497+
498+ async def handshake (self ) -> None :
499+ pass
500+
501+ async def get_leader (self ) -> tuple [int , str ]:
502+ return (1 , "localhost:9001" )
503+
504+ with (
505+ patch ("asyncio.open_connection" , return_value = (mock_reader , mock_writer )),
506+ patch ("dqliteclient.cluster.DqliteProtocol" , FakeProto ),
507+ ):
508+ await client ._query_leader ("localhost:9001" , trust_server_heartbeat = True )
509+
510+ assert captured .get ("trust_server_heartbeat" ) is True
511+
512+ async def test_flag_default_false (self ) -> None :
513+ store = MemoryNodeStore (["localhost:9001" ])
514+ client = ClusterClient (store , timeout = 1.0 )
515+
516+ mock_reader = AsyncMock ()
517+ mock_writer = MagicMock ()
518+ mock_writer .drain = AsyncMock ()
519+ mock_writer .close = MagicMock ()
520+ mock_writer .wait_closed = AsyncMock ()
521+
522+ captured : dict [str , object ] = {}
523+
524+ class FakeProto :
525+ def __init__ (self , * args : object , ** kwargs : object ) -> None :
526+ captured .update (kwargs )
527+
528+ async def handshake (self ) -> None :
529+ pass
530+
531+ async def get_leader (self ) -> tuple [int , str ]:
532+ return (1 , "localhost:9001" )
533+
534+ with (
535+ patch ("asyncio.open_connection" , return_value = (mock_reader , mock_writer )),
536+ patch ("dqliteclient.cluster.DqliteProtocol" , FakeProto ),
537+ ):
538+ await client ._query_leader ("localhost:9001" )
539+
540+ assert captured .get ("trust_server_heartbeat" ) is False
541+
0 commit comments