Skip to content

Commit 94da6ad

Browse files
fix: reorder decode_handshake commit to survive async exceptions
decode_handshake used to commit via three sequential statements: self._buffer.read_bytes(8) # (1) consume bytes self._version = version # (2) record version self._handshake_done = True # (3) mark done A KeyboardInterrupt (or any other PyErr_SetAsyncExc delivery) landing between (1) and (3) left the buffer with the 8 handshake bytes consumed but _handshake_done still False. On retry, decode_handshake re-peeked 8 bytes — but those were now the first 8 bytes of the NEXT message, which almost always fail the _SUPPORTED_VERSIONS check. The caller then saw a misleading "Unsupported protocol version" error pointing at what was actually valid message data. Reorder: commit _version/_handshake_done FIRST, then consume. If read_bytes(8) raises for any BaseException subclass, revert the state commit so the buffer is still coherent (8 bytes still available) and a retry repeats the peek/validate/commit cycle cleanly. Does NOT fix hazard 1 from issue 041 (two-thread TOCTOU) — that remains issue 021's single-owner contract territory. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c1f5794 commit 94da6ad

2 files changed

Lines changed: 153 additions & 9 deletions

File tree

src/dqlitewire/codec.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,19 @@ def decode_handshake(self) -> int | None:
358358
On an unsupported version, the 8 handshake bytes are left in the
359359
buffer untouched so that a retry is deterministic (same bytes, same
360360
error) rather than silently consuming the next 8 bytes of real data.
361+
362+
Signal-safety (issue 041): the commit order is
363+
``_version``/``_handshake_done`` FIRST, then ``read_bytes(8)``.
364+
If an async exception (``KeyboardInterrupt``) lands between the
365+
state commit and the buffer consume, the except block reverts
366+
``_handshake_done`` and ``_version`` so the buffer is still
367+
coherent: the 8 handshake bytes are still there and a retry
368+
repeats the peek/validate/commit cycle. This replaces the
369+
previous "consume then commit" order, which allowed a signal
370+
to leave the bytes consumed but the state not yet marked —
371+
retry would then re-peek 8 bytes of real message data as a
372+
handshake and almost always raise a misleading "Unsupported
373+
protocol version" error.
361374
"""
362375
if self._handshake_done:
363376
raise ProtocolError("Handshake already completed")
@@ -370,10 +383,17 @@ def decode_handshake(self) -> int | None:
370383
version = int.from_bytes(peek, "little")
371384
if version not in _SUPPORTED_VERSIONS:
372385
raise ProtocolError(f"Unsupported protocol version: {version:#x}")
373-
# Valid — commit by advancing past the handshake bytes.
374-
self._buffer.read_bytes(8)
386+
# Commit state BEFORE consuming bytes. If the consume is
387+
# interrupted by an async exception, revert so the peek/commit
388+
# pair becomes atomic from the caller's perspective.
375389
self._version = version
376390
self._handshake_done = True
391+
try:
392+
self._buffer.read_bytes(8)
393+
except BaseException:
394+
self._handshake_done = False
395+
self._version = None
396+
raise
377397
return version
378398

379399

tests/test_codec_signal_safety.py

Lines changed: 131 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Signal-safety tests for MessageDecoder (issue 045).
1+
"""Signal-safety tests for MessageDecoder (issues 041 and 045).
22
33
``MessageDecoder.decode()`` and ``decode_continuation()`` both consume
44
bytes from the buffer via ``read_message()`` and then parse them via
@@ -15,9 +15,18 @@
1515
caller catches it at the top level and retries; the retry reads the
1616
*next* message boundary, silently losing exactly one message.
1717
18+
``MessageDecoder.decode_handshake()`` has an analogous hazard: the
19+
commit used to be (1) consume bytes, (2) set ``_version``, (3) set
20+
``_handshake_done``. A signal between (1) and (3) left the bytes
21+
consumed but state not marked; the retry re-peeked the first 8 bytes
22+
of the *next* message as a handshake and raised a misleading
23+
"Unsupported protocol version" error.
24+
1825
These tests use ``sys.settrace`` to inject a ``KeyboardInterrupt`` at
19-
a specific source line inside ``decode_bytes`` and assert that the
20-
decoder is left poisoned.
26+
a specific source line inside a decoder method — the most reliable
27+
way to exercise CPython's bytecode-boundary async-exception model
28+
without relying on wallclock timing — and assert the decoder ends up
29+
in a state the caller can detect and recover from, not silently torn.
2130
"""
2231

2332
from __future__ import annotations
@@ -31,7 +40,7 @@
3140
import pytest
3241

3342
from dqlitewire.codec import MessageDecoder
34-
from dqlitewire.constants import ResponseType
43+
from dqlitewire.constants import PROTOCOL_VERSION, ResponseType
3544
from dqlitewire.exceptions import DecodeError, ProtocolError
3645
from dqlitewire.messages import DbResponse, LeaderResponse
3746

@@ -65,10 +74,12 @@ def tracer(frame: FrameType, event: str, arg: object) -> Any:
6574

6675

6776
class TestDecodeSignalSafety:
68-
def test_keyboard_interrupt_inside_parse_leaves_decoder_poisoned(self) -> None:
69-
"""Regression for issue 045.
77+
"""Regression tests for issue 045 (consumed-but-unpoisoned window
78+
in ``decode()`` / ``decode_continuation()``).
79+
"""
7080

71-
A ``KeyboardInterrupt`` delivered while ``decode_bytes`` is
81+
def test_keyboard_interrupt_inside_parse_leaves_decoder_poisoned(self) -> None:
82+
"""A ``KeyboardInterrupt`` delivered while ``decode_bytes`` is
7283
parsing a consumed message used to propagate without
7384
poisoning the decoder, because ``except Exception`` does not
7485
catch ``BaseException``. The retry read the next message
@@ -177,3 +188,116 @@ def test_db_response_decodes_cleanly_without_interrupt() -> None:
177188
msg = dec.decode()
178189
assert isinstance(msg, DbResponse)
179190
assert msg.db_id == 42
191+
192+
193+
class TestDecodeHandshakeSignalSafety:
194+
"""Regression tests for issue 041 (hazard 2: single-threaded signal
195+
split inside ``decode_handshake``).
196+
197+
The pre-fix ``decode_handshake`` committed via three separate
198+
statements:
199+
200+
self._buffer.read_bytes(8) # consume bytes
201+
self._version = version # record version
202+
self._handshake_done = True # mark done
203+
204+
A ``KeyboardInterrupt`` delivered between ``read_bytes(8)`` and
205+
the final store left the buffer with the 8 handshake bytes
206+
consumed but ``_handshake_done`` still ``False``. On retry,
207+
``decode_handshake`` re-peeked ``peek_bytes(8)`` — but those were
208+
now the first 8 bytes of the **next** message, which almost
209+
always fail the ``_SUPPORTED_VERSIONS`` check. The caller saw a
210+
misleading "Unsupported protocol version" error pointing at what
211+
was actually legitimate message data.
212+
213+
The fix marks ``_handshake_done`` BEFORE consuming the bytes and
214+
reverts on failure. An interrupt between the mark and the consume
215+
leaves the buffer unconsumed and the state "already completed";
216+
the retry raises a deterministic error instead of silently
217+
misreading real data.
218+
"""
219+
220+
def test_keyboard_interrupt_post_consume_is_not_silently_misleading(
221+
self,
222+
) -> None:
223+
dec = MessageDecoder(is_request=True)
224+
# 8 valid handshake bytes followed by 8 bytes that WOULD fail
225+
# the version check if the retry misread them as a handshake.
226+
dec.feed(PROTOCOL_VERSION.to_bytes(8, "little") + b"\xff" * 8)
227+
228+
state = {"raised": False}
229+
230+
def tracer(frame: FrameType, event: str, arg: object) -> Any:
231+
if event != "line":
232+
return tracer
233+
if frame.f_code.co_name != "decode_handshake":
234+
return tracer
235+
if state["raised"]:
236+
return tracer
237+
try:
238+
with open(frame.f_code.co_filename) as f:
239+
src_line = f.readlines()[frame.f_lineno - 1]
240+
except OSError:
241+
return tracer
242+
# Inject at the `self._version = version` line — it is
243+
# present in both the pre-fix and post-fix layouts, and
244+
# in the pre-fix source it runs AFTER read_bytes(8).
245+
if "self._version = version" in src_line:
246+
state["raised"] = True
247+
raise KeyboardInterrupt("injected mid-commit")
248+
return tracer
249+
250+
sys.settrace(tracer)
251+
try:
252+
with contextlib.suppress(KeyboardInterrupt):
253+
dec.decode_handshake()
254+
finally:
255+
sys.settrace(None)
256+
257+
# After the torn window, a retry must NOT report
258+
# "Unsupported protocol version" for what were actually real
259+
# bytes of the next message (0xff*8). Acceptable outcomes:
260+
# 1. The fix reverted state — retry succeeds with the
261+
# original 8 handshake bytes still in the buffer.
262+
# 2. The fix set _handshake_done before the consume and
263+
# did not revert — retry raises "already completed".
264+
retry_err: Exception | None = None
265+
retry_result: int | None = None
266+
try:
267+
retry_result = dec.decode_handshake()
268+
except ProtocolError as e:
269+
retry_err = e
270+
271+
if retry_err is not None:
272+
assert "Unsupported" not in str(retry_err), (
273+
f"retry reported misleading error: {retry_err}"
274+
)
275+
else:
276+
assert retry_result == PROTOCOL_VERSION
277+
assert dec._handshake_done is True
278+
279+
def test_happy_path_handshake_still_works(self) -> None:
280+
"""Sanity: without any interrupt, decode_handshake completes
281+
normally and sets all three state bits.
282+
"""
283+
dec = MessageDecoder(is_request=True)
284+
dec.feed(PROTOCOL_VERSION.to_bytes(8, "little"))
285+
version = dec.decode_handshake()
286+
assert version == PROTOCOL_VERSION
287+
assert dec._handshake_done is True
288+
assert dec._version == PROTOCOL_VERSION
289+
290+
def test_unsupported_version_does_not_consume_bytes(self) -> None:
291+
"""Counter-test for issues 027 and 041: an unsupported version
292+
must leave the 8 peeked bytes in the buffer so that retry
293+
semantics remain deterministic.
294+
"""
295+
dec = MessageDecoder(is_request=True)
296+
bogus = (0xDEADBEEF).to_bytes(8, "little")
297+
dec.feed(bogus)
298+
299+
with pytest.raises(ProtocolError, match="Unsupported"):
300+
dec.decode_handshake()
301+
302+
assert dec._handshake_done is False
303+
assert dec._buffer.available() == 8

0 commit comments

Comments
 (0)