Skip to content

Commit 2d19e03

Browse files
fix: peek-before-consume in decode_handshake()
Previously the handshake path read 8 bytes BEFORE validating the version, so an unsupported version left the buffer advanced by 8 with _handshake_done still False. A retry consumed the NEXT 8 bytes as a "version" — almost always the header of the first real message — and silently desynchronized the stream. Add ReadBuffer.peek_bytes(n) for non-consuming reads, then rewrite decode_handshake() to peek the version bytes, validate, and only call read_bytes(8) on success. An invalid version now leaves the bytes in place: a retry is deterministic (same bytes, same error) instead of silently advancing into real message data. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d21414f commit 2d19e03

3 files changed

Lines changed: 128 additions & 3 deletions

File tree

src/dqlitewire/buffer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,18 @@ def read_bytes(self, n: int) -> bytes | None:
218218
self._maybe_compact()
219219
return data
220220

221+
def peek_bytes(self, n: int) -> bytes | None:
222+
"""Return the next n bytes without advancing the read position.
223+
224+
Returns None if fewer than n bytes are available. Symmetric with
225+
``read_bytes(n)`` but non-consuming — use this when you need to
226+
validate the bytes before deciding whether to consume them.
227+
"""
228+
available = len(self._data) - self._pos
229+
if available < n:
230+
return None
231+
return bytes(self._data[self._pos : self._pos + n])
232+
221233
def _maybe_compact(self) -> None:
222234
"""Compact buffer if we've consumed a lot."""
223235
if self._pos > 4096:

src/dqlitewire/codec.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,15 +333,24 @@ def decode_handshake(self) -> int | None:
333333
Must be called before decode() on request decoders.
334334
Raises ProtocolError if the version is not recognized or if
335335
the handshake was already completed.
336+
337+
On an unsupported version, the 8 handshake bytes are left in the
338+
buffer untouched so that a retry is deterministic (same bytes, same
339+
error) rather than silently consuming the next 8 bytes of real data.
336340
"""
337341
if self._handshake_done:
338342
raise ProtocolError("Handshake already completed")
339-
data = self._buffer.read_bytes(8)
340-
if data is None:
343+
# Peek first so we only commit on a valid version. An invalid version
344+
# leaves the bytes in place — a retry is deterministic rather than
345+
# silently advancing into real message data.
346+
peek = self._buffer.peek_bytes(8)
347+
if peek is None:
341348
return None
342-
version = int.from_bytes(data, "little")
349+
version = int.from_bytes(peek, "little")
343350
if version not in _SUPPORTED_VERSIONS:
344351
raise ProtocolError(f"Unsupported protocol version: {version:#x}")
352+
# Valid — commit by advancing past the handshake bytes.
353+
self._buffer.read_bytes(8)
345354
self._version = version
346355
self._handshake_done = True
347356
return version

tests/test_codec.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,110 @@ def test_decode_handshake_rejects_unknown_version(self) -> None:
221221
with pytest.raises(ProtocolError, match="[Uu]nsupported protocol version"):
222222
decoder.decode_handshake()
223223

224+
def test_decode_handshake_failure_does_not_consume_bytes(self) -> None:
225+
"""On an unsupported version, decode_handshake() must NOT consume the
226+
handshake bytes.
227+
228+
Previously the method read 8 bytes BEFORE validating the version, so
229+
a failed handshake left the buffer advanced by 8 with ``_handshake_done``
230+
still False. A retry consumed the next 8 bytes as a "version", which
231+
was almost always the header of a real message — silently desynchronizing
232+
the stream. Peek-before-consume means the bytes stay in the buffer and
233+
a retry is deterministic: same bytes, same error.
234+
"""
235+
from dqlitewire.exceptions import ProtocolError
236+
237+
decoder = MessageDecoder(is_request=True)
238+
bogus = b"\x42" * 8
239+
decoder.feed(bogus)
240+
241+
with pytest.raises(ProtocolError, match="[Uu]nsupported protocol version"):
242+
decoder.decode_handshake()
243+
244+
# The 8 handshake bytes must still be in the buffer.
245+
assert decoder._buffer.available() == 8
246+
assert not decoder._handshake_done
247+
248+
# A retry on the same bytes gets the same error, deterministically.
249+
with pytest.raises(ProtocolError, match="[Uu]nsupported protocol version"):
250+
decoder.decode_handshake()
251+
assert decoder._buffer.available() == 8
252+
253+
def test_decode_handshake_partial_data_leaves_bytes_intact(self) -> None:
254+
"""With fewer than 8 bytes buffered, decode_handshake() returns None
255+
and leaves the partial data untouched so a subsequent feed() can
256+
complete the handshake.
257+
"""
258+
decoder = MessageDecoder(is_request=True)
259+
version_bytes = PROTOCOL_VERSION_LEGACY.to_bytes(8, "little")
260+
decoder.feed(version_bytes[:4])
261+
assert decoder.decode_handshake() is None
262+
assert decoder._buffer.available() == 4
263+
264+
decoder.feed(version_bytes[4:])
265+
assert decoder.decode_handshake() == PROTOCOL_VERSION_LEGACY
266+
267+
def test_decode_handshake_failure_preserves_following_bytes(self) -> None:
268+
"""Direct reproducer of the original bug: a buffer containing a bogus
269+
version followed by a real valid version. The pre-fix code consumed
270+
both 8-byte chunks across two retries; the fix must consume neither
271+
on the first failure.
272+
"""
273+
from dqlitewire.exceptions import ProtocolError
274+
275+
decoder = MessageDecoder(is_request=True)
276+
bogus = b"\x42" * 8
277+
valid = PROTOCOL_VERSION_LEGACY.to_bytes(8, "little")
278+
decoder.feed(bogus + valid)
279+
280+
with pytest.raises(ProtocolError, match="[Uu]nsupported protocol version"):
281+
decoder.decode_handshake()
282+
# All 16 bytes still in the buffer — neither chunk has been consumed.
283+
assert decoder._buffer.available() == 16
284+
285+
# Even after a retry, the valid bytes behind are untouched.
286+
with pytest.raises(ProtocolError, match="[Uu]nsupported protocol version"):
287+
decoder.decode_handshake()
288+
assert decoder._buffer.available() == 16
289+
290+
def test_decode_handshake_recoverable_via_reset(self) -> None:
291+
"""After a handshake failure, reset() clears the buffer and the
292+
decoder accepts a fresh handshake on a reconnect.
293+
"""
294+
from dqlitewire.exceptions import ProtocolError
295+
296+
decoder = MessageDecoder(is_request=True)
297+
decoder.feed(b"\x42" * 8)
298+
with pytest.raises(ProtocolError, match="[Uu]nsupported protocol version"):
299+
decoder.decode_handshake()
300+
301+
decoder.reset()
302+
assert decoder._buffer.available() == 0
303+
assert not decoder._handshake_done
304+
305+
# Fresh valid handshake works.
306+
decoder.feed(PROTOCOL_VERSION.to_bytes(8, "little"))
307+
assert decoder.decode_handshake() == PROTOCOL_VERSION
308+
assert decoder._handshake_done
309+
310+
def test_peek_bytes_does_not_advance_position(self) -> None:
311+
"""Sanity: ReadBuffer.peek_bytes() must return the requested bytes
312+
without advancing _pos. This is the primitive decode_handshake()
313+
depends on.
314+
"""
315+
from dqlitewire.buffer import ReadBuffer
316+
317+
buf = ReadBuffer()
318+
buf.feed(b"abcdefghij")
319+
assert buf.peek_bytes(4) == b"abcd"
320+
assert buf.available() == 10
321+
# Repeated peeks return the same bytes.
322+
assert buf.peek_bytes(4) == b"abcd"
323+
assert buf.available() == 10
324+
# Asking for more than available returns None.
325+
assert buf.peek_bytes(20) is None
326+
assert buf.available() == 10
327+
224328
def test_decode_handshake_rejects_zero_version(self) -> None:
225329
"""Version 0 is not a valid protocol version."""
226330
from dqlitewire.exceptions import ProtocolError

0 commit comments

Comments
 (0)