@@ -63,13 +63,13 @@ def test_encode_message(self) -> None:
6363
6464class TestMessageDecoder :
6565 def test_decode_handshake (self ) -> None :
66- decoder = MessageDecoder ()
66+ decoder = MessageDecoder (is_request = True )
6767 decoder .feed (PROTOCOL_VERSION .to_bytes (8 , "little" ))
6868 version = decoder .decode_handshake ()
6969 assert version == PROTOCOL_VERSION
7070
7171 def test_decode_handshake_partial (self ) -> None :
72- decoder = MessageDecoder ()
72+ decoder = MessageDecoder (is_request = True )
7373 decoder .feed (b"\x01 \x00 \x00 " )
7474 version = decoder .decode_handshake ()
7575 assert version is None
@@ -274,6 +274,41 @@ def test_request_decoder_decode_bytes_rejects_before_handshake(self) -> None:
274274 with pytest .raises (ProtocolError , match = "[Hh]andshake" ):
275275 decoder .decode_bytes (msg .encode ())
276276
277+ def test_double_decode_handshake_raises (self ) -> None :
278+ """Calling decode_handshake() twice must raise ProtocolError.
279+
280+ A second call would consume 8 bytes of actual message data from the
281+ buffer and interpret them as a version number, silently corrupting
282+ the stream.
283+ """
284+ from dqlitewire .exceptions import ProtocolError
285+
286+ decoder = MessageDecoder (is_request = True )
287+ handshake = PROTOCOL_VERSION .to_bytes (8 , "little" )
288+ msg = LeaderRequest ()
289+ decoder .feed (handshake + msg .encode ())
290+
291+ # First handshake succeeds
292+ version = decoder .decode_handshake ()
293+ assert version == PROTOCOL_VERSION
294+
295+ # Second handshake must raise, not consume message bytes
296+ with pytest .raises (ProtocolError , match = "[Hh]andshake already completed" ):
297+ decoder .decode_handshake ()
298+
299+ # The message should still be decodable (not consumed by second handshake)
300+ decoded = decoder .decode ()
301+ assert isinstance (decoded , LeaderRequest )
302+
303+ def test_decode_handshake_on_client_decoder_raises (self ) -> None :
304+ """Client-side decoder should reject decode_handshake() since _handshake_done=True."""
305+ from dqlitewire .exceptions import ProtocolError
306+
307+ decoder = MessageDecoder (is_request = False )
308+ decoder .feed (PROTOCOL_VERSION .to_bytes (8 , "little" ))
309+ with pytest .raises (ProtocolError , match = "[Hh]andshake already completed" ):
310+ decoder .decode_handshake ()
311+
277312 def test_response_decoder_decode_bytes_works_without_handshake (self ) -> None :
278313 """decode_bytes() on response decoders should work without handshake."""
279314 decoder = MessageDecoder (is_request = False )
@@ -292,15 +327,11 @@ def test_response_decoder_allows_decode_without_handshake(self) -> None:
292327 assert isinstance (decoded , LeaderResponse )
293328
294329 def test_legacy_handshake_decodes_leader_response_as_legacy (self ) -> None :
295- """After legacy handshake, LeaderResponse should use legacy format."""
296- from dqlitewire .types import encode_text
297-
298- decoder = MessageDecoder (is_request = False )
299- # Feed legacy handshake + legacy LeaderResponse (address-only, no node_id)
300- handshake = PROTOCOL_VERSION_LEGACY .to_bytes (8 , "little" )
301- # Build a LeaderResponse with legacy body: just text address
330+ """Legacy version should decode LeaderResponse in legacy format."""
331+ from dqlitewire .codec import decode_message
302332 from dqlitewire .constants import ResponseType
303333 from dqlitewire .messages .base import Header
334+ from dqlitewire .types import encode_text
304335
305336 address = "192.168.1.1:9001"
306337 body = encode_text (address )
@@ -309,32 +340,27 @@ def test_legacy_handshake_decodes_leader_response_as_legacy(self) -> None:
309340 msg_type = ResponseType .LEADER ,
310341 schema = 0 ,
311342 )
312- decoder .feed (handshake + header .encode () + body )
313- version = decoder .decode_handshake ()
314- assert version == PROTOCOL_VERSION_LEGACY
343+ data = header .encode () + body
315344
316- decoded = decoder . decode ( )
345+ decoded = decode_message ( data , is_request = False , version = PROTOCOL_VERSION_LEGACY )
317346 assert isinstance (decoded , LeaderResponse )
318347 assert decoded .node_id == 0
319348 assert decoded .address == address
320349
321350 def test_modern_handshake_decodes_leader_response_as_modern (self ) -> None :
322- """After modern handshake, LeaderResponse should use modern format."""
323- decoder = MessageDecoder ( is_request = False )
324- handshake = PROTOCOL_VERSION . to_bytes ( 8 , "little" )
351+ """Modern version should decode LeaderResponse in modern format."""
352+ from dqlitewire . codec import decode_message
353+
325354 msg = LeaderResponse (node_id = 42 , address = "node1:9001" )
326- decoder .feed (handshake + msg .encode ())
327- version = decoder .decode_handshake ()
328- assert version == PROTOCOL_VERSION
329355
330- decoded = decoder . decode ( )
356+ decoded = decode_message ( msg . encode (), is_request = False , version = PROTOCOL_VERSION )
331357 assert isinstance (decoded , LeaderResponse )
332358 assert decoded .node_id == 42
333359 assert decoded .address == "node1:9001"
334360
335361 def test_decoder_version_property (self ) -> None :
336- """Decoder should expose the protocol version after handshake."""
337- decoder = MessageDecoder (is_request = False )
362+ """Request decoder should expose the protocol version after handshake."""
363+ decoder = MessageDecoder (is_request = True )
338364 assert decoder .version is None
339365 decoder .feed (PROTOCOL_VERSION .to_bytes (8 , "little" ))
340366 decoder .decode_handshake ()
0 commit comments