Skip to content

Commit 2efce12

Browse files
Guard decode_handshake() against double-call that silently eats message bytes
A second decode_handshake() call would consume 8 bytes of actual message data from the buffer, interpreting them as a version number. Added a guard that raises ProtocolError if the handshake was already completed. Also fixed tests that incorrectly called decode_handshake() on client-side decoders. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7baf09b commit 2efce12

2 files changed

Lines changed: 52 additions & 23 deletions

File tree

src/dqlitewire/codec.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,11 @@ def decode_handshake(self) -> int | None:
221221
222222
Returns the protocol version or None if not enough data.
223223
Must be called before decode() on request decoders.
224-
Raises ProtocolError if the version is not recognized.
224+
Raises ProtocolError if the version is not recognized or if
225+
the handshake was already completed.
225226
"""
227+
if self._handshake_done:
228+
raise ProtocolError("Handshake already completed")
226229
data = self._buffer.read_bytes(8)
227230
if data is None:
228231
return None

tests/test_codec.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ def test_encode_message(self) -> None:
6363

6464
class TestMessageDecoder:
6565
def test_decode_handshake(self) -> None:
66-
decoder = MessageDecoder()
66+
decoder = MessageDecoder(is_request=True)
6767
decoder.feed(PROTOCOL_VERSION.to_bytes(8, "little"))
6868
version = decoder.decode_handshake()
6969
assert version == PROTOCOL_VERSION
7070

7171
def test_decode_handshake_partial(self) -> None:
72-
decoder = MessageDecoder()
72+
decoder = MessageDecoder(is_request=True)
7373
decoder.feed(b"\x01\x00\x00")
7474
version = decoder.decode_handshake()
7575
assert version is None
@@ -274,6 +274,41 @@ def test_request_decoder_decode_bytes_rejects_before_handshake(self) -> None:
274274
with pytest.raises(ProtocolError, match="[Hh]andshake"):
275275
decoder.decode_bytes(msg.encode())
276276

277+
def test_double_decode_handshake_raises(self) -> None:
278+
"""Calling decode_handshake() twice must raise ProtocolError.
279+
280+
A second call would consume 8 bytes of actual message data from the
281+
buffer and interpret them as a version number, silently corrupting
282+
the stream.
283+
"""
284+
from dqlitewire.exceptions import ProtocolError
285+
286+
decoder = MessageDecoder(is_request=True)
287+
handshake = PROTOCOL_VERSION.to_bytes(8, "little")
288+
msg = LeaderRequest()
289+
decoder.feed(handshake + msg.encode())
290+
291+
# First handshake succeeds
292+
version = decoder.decode_handshake()
293+
assert version == PROTOCOL_VERSION
294+
295+
# Second handshake must raise, not consume message bytes
296+
with pytest.raises(ProtocolError, match="[Hh]andshake already completed"):
297+
decoder.decode_handshake()
298+
299+
# The message should still be decodable (not consumed by second handshake)
300+
decoded = decoder.decode()
301+
assert isinstance(decoded, LeaderRequest)
302+
303+
def test_decode_handshake_on_client_decoder_raises(self) -> None:
304+
"""Client-side decoder should reject decode_handshake() since _handshake_done=True."""
305+
from dqlitewire.exceptions import ProtocolError
306+
307+
decoder = MessageDecoder(is_request=False)
308+
decoder.feed(PROTOCOL_VERSION.to_bytes(8, "little"))
309+
with pytest.raises(ProtocolError, match="[Hh]andshake already completed"):
310+
decoder.decode_handshake()
311+
277312
def test_response_decoder_decode_bytes_works_without_handshake(self) -> None:
278313
"""decode_bytes() on response decoders should work without handshake."""
279314
decoder = MessageDecoder(is_request=False)
@@ -292,15 +327,11 @@ def test_response_decoder_allows_decode_without_handshake(self) -> None:
292327
assert isinstance(decoded, LeaderResponse)
293328

294329
def test_legacy_handshake_decodes_leader_response_as_legacy(self) -> None:
295-
"""After legacy handshake, LeaderResponse should use legacy format."""
296-
from dqlitewire.types import encode_text
297-
298-
decoder = MessageDecoder(is_request=False)
299-
# Feed legacy handshake + legacy LeaderResponse (address-only, no node_id)
300-
handshake = PROTOCOL_VERSION_LEGACY.to_bytes(8, "little")
301-
# Build a LeaderResponse with legacy body: just text address
330+
"""Legacy version should decode LeaderResponse in legacy format."""
331+
from dqlitewire.codec import decode_message
302332
from dqlitewire.constants import ResponseType
303333
from dqlitewire.messages.base import Header
334+
from dqlitewire.types import encode_text
304335

305336
address = "192.168.1.1:9001"
306337
body = encode_text(address)
@@ -309,32 +340,27 @@ def test_legacy_handshake_decodes_leader_response_as_legacy(self) -> None:
309340
msg_type=ResponseType.LEADER,
310341
schema=0,
311342
)
312-
decoder.feed(handshake + header.encode() + body)
313-
version = decoder.decode_handshake()
314-
assert version == PROTOCOL_VERSION_LEGACY
343+
data = header.encode() + body
315344

316-
decoded = decoder.decode()
345+
decoded = decode_message(data, is_request=False, version=PROTOCOL_VERSION_LEGACY)
317346
assert isinstance(decoded, LeaderResponse)
318347
assert decoded.node_id == 0
319348
assert decoded.address == address
320349

321350
def test_modern_handshake_decodes_leader_response_as_modern(self) -> None:
322-
"""After modern handshake, LeaderResponse should use modern format."""
323-
decoder = MessageDecoder(is_request=False)
324-
handshake = PROTOCOL_VERSION.to_bytes(8, "little")
351+
"""Modern version should decode LeaderResponse in modern format."""
352+
from dqlitewire.codec import decode_message
353+
325354
msg = LeaderResponse(node_id=42, address="node1:9001")
326-
decoder.feed(handshake + msg.encode())
327-
version = decoder.decode_handshake()
328-
assert version == PROTOCOL_VERSION
329355

330-
decoded = decoder.decode()
356+
decoded = decode_message(msg.encode(), is_request=False, version=PROTOCOL_VERSION)
331357
assert isinstance(decoded, LeaderResponse)
332358
assert decoded.node_id == 42
333359
assert decoded.address == "node1:9001"
334360

335361
def test_decoder_version_property(self) -> None:
336-
"""Decoder should expose the protocol version after handshake."""
337-
decoder = MessageDecoder(is_request=False)
362+
"""Request decoder should expose the protocol version after handshake."""
363+
decoder = MessageDecoder(is_request=True)
338364
assert decoder.version is None
339365
decoder.feed(PROTOCOL_VERSION.to_bytes(8, "little"))
340366
decoder.decode_handshake()

0 commit comments

Comments
 (0)