Skip to content

Commit 1ac48ec

Browse files
fix: enforce ROWS continuation drain invariant via state flag
Add a `_continuation_expected` flag to `MessageDecoder`. When `decode()` returns a `RowsResponse` with `has_more=True`, the flag is set. While set, subsequent `decode()` calls raise `ProtocolError` — the next frame in the buffer is a ROWS continuation (no column header prefix), so `decode()` would misparse it. Callers must use `decode_continuation()` until `has_more` is `False`, or `reset()`. This matches Go's `protocol.Protocol` pattern where `Call()` holds the mutex across the entire ROWS response sequence, and `Rows.Close()` drains remaining frames before releasing the connection. The flag converts silent stream desynchronization (decode reads continuation data as a new message) into a loud `ProtocolError` at the point of misuse. The flag is cleared by `decode_continuation()` when the final frame (`has_more=False`) is decoded, and by `reset()` for abandoning the stream. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a22da9e commit 1ac48ec

2 files changed

Lines changed: 164 additions & 2 deletions

File tree

src/dqlitewire/codec.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def __init__(self, is_request: bool = False, version: int = PROTOCOL_VERSION) ->
180180
# handshake before decoding any messages. Response decoders (client-side)
181181
# don't receive an inbound handshake, so they skip this check.
182182
self._handshake_done = not is_request
183+
self._continuation_expected = False
183184

184185
@property
185186
def version(self) -> int | None:
@@ -204,6 +205,7 @@ def reset(self) -> None:
204205
handshake state to "not yet received". Use this after a reconnect.
205206
"""
206207
self._buffer.reset()
208+
self._continuation_expected = False
207209
if self._is_request:
208210
self._handshake_done = False
209211
self._version = None
@@ -280,9 +282,12 @@ def decode_continuation(
280282
f"got type {header.msg_type}"
281283
)
282284

283-
return RowsResponse.decode_rows_continuation(
285+
result = RowsResponse.decode_rows_continuation(
284286
body, column_names, column_count, max_rows=max_rows
285287
)
288+
if not result.has_more:
289+
self._continuation_expected = False
290+
return result
286291
except BaseException as e:
287292
# poison() stores Exception | None; wrap non-Exception
288293
# BaseException subclasses so the poison cause is still a
@@ -300,8 +305,17 @@ def decode(self) -> Message | None:
300305
Returns None if no complete message is available.
301306
Raises ProtocolError if called on a request decoder before decode_handshake().
302307
Raises ProtocolError if the decoder is poisoned.
308+
Raises ProtocolError if a ROWS continuation is in progress
309+
(call ``decode_continuation()`` until ``has_more`` is ``False``,
310+
or ``reset()`` to abandon the stream).
303311
"""
304312
self._buffer._check_poisoned()
313+
if self._continuation_expected:
314+
raise ProtocolError(
315+
"Cannot decode a new message while a ROWS continuation "
316+
"is in progress. Call decode_continuation() until "
317+
"has_more is False, or call reset() to abandon the stream."
318+
)
305319
if not self._handshake_done:
306320
raise ProtocolError(
307321
"Protocol handshake not yet received. Call decode_handshake() before decode()."
@@ -321,7 +335,7 @@ def decode(self) -> Message | None:
321335
# struct.error, ValueError, UnicodeDecodeError, IndexError,
322336
# etc., and all of them mean the stream is desynchronized.
323337
try:
324-
return self.decode_bytes(data)
338+
msg = self.decode_bytes(data)
325339
except BaseException as e:
326340
# poison() stores Exception | None; wrap non-Exception
327341
# BaseException subclasses so the poison cause is still a
@@ -333,6 +347,11 @@ def decode(self) -> Message | None:
333347
)
334348
raise
335349

350+
if isinstance(msg, RowsResponse) and msg.has_more:
351+
self._continuation_expected = True
352+
353+
return msg
354+
336355
def decode_bytes(self, data: bytes) -> Message:
337356
"""Decode a message from bytes.
338357

tests/test_codec.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,149 @@ def test_decode_continuation_raises_on_unexpected_type(self) -> None:
939939
decoder.decode_continuation(column_names=["id"], column_count=1)
940940

941941

942+
class TestDecoderContinuationExpected:
943+
"""Regression tests for issue 058.
944+
945+
When ``decode()`` returns a ``RowsResponse`` with ``has_more=True``,
946+
the decoder enters a "continuation expected" state. Calling ``decode()``
947+
again before draining all continuations via ``decode_continuation()``
948+
is a protocol error — the next frame in the buffer is a continuation
949+
(no column header prefix), so ``decode()`` would misparse it. The
950+
``_continuation_expected`` flag makes this misuse fail loudly with
951+
``ProtocolError`` instead of producing silent stream desynchronization.
952+
"""
953+
954+
def test_decode_raises_when_continuation_expected(self) -> None:
955+
"""decode() must refuse while a ROWS continuation is in progress."""
956+
from dqlitewire.constants import ROW_PART_MARKER, ValueType
957+
from dqlitewire.exceptions import ProtocolError
958+
from dqlitewire.messages.base import Header
959+
from dqlitewire.messages.responses import RowsResponse
960+
from dqlitewire.tuples import encode_row_header, encode_row_values
961+
from dqlitewire.types import encode_text, encode_uint64
962+
963+
types = [ValueType.INTEGER]
964+
965+
# Build a RowsResponse with has_more=True
966+
body = encode_uint64(1) # column_count
967+
body += encode_text("id")
968+
body += encode_row_header(types)
969+
body += encode_row_values([1], types)
970+
body += encode_uint64(ROW_PART_MARKER)
971+
header = Header(size_words=len(body) // 8, msg_type=7, schema=0)
972+
msg_bytes = header.encode() + body
973+
974+
# Also feed a second (standalone) message that decode() would try to read
975+
second = ResultResponse(last_insert_id=0, rows_affected=0).encode()
976+
977+
decoder = MessageDecoder(is_request=False)
978+
decoder.feed(msg_bytes + second)
979+
980+
# First decode() returns the initial RowsResponse with has_more=True
981+
result = decoder.decode()
982+
assert isinstance(result, RowsResponse)
983+
assert result.has_more is True
984+
985+
# Second decode() must raise — we're in "continuation expected" state
986+
with pytest.raises(ProtocolError, match="continuation"):
987+
decoder.decode()
988+
989+
def test_decode_continuation_clears_flag(self) -> None:
990+
"""After draining all continuations (has_more=False), decode() works again."""
991+
from dqlitewire.constants import ROW_DONE_MARKER, ROW_PART_MARKER, ValueType
992+
from dqlitewire.messages.base import Header
993+
from dqlitewire.messages.responses import RowsResponse
994+
from dqlitewire.tuples import encode_row_header, encode_row_values
995+
from dqlitewire.types import encode_text, encode_uint64
996+
997+
types = [ValueType.INTEGER]
998+
999+
# Initial frame (has_more=True)
1000+
body1 = encode_uint64(1)
1001+
body1 += encode_text("id")
1002+
body1 += encode_row_header(types)
1003+
body1 += encode_row_values([1], types)
1004+
body1 += encode_uint64(ROW_PART_MARKER)
1005+
h1 = Header(size_words=len(body1) // 8, msg_type=7, schema=0)
1006+
1007+
# Continuation frame (has_more=False)
1008+
body2 = encode_row_header(types)
1009+
body2 += encode_row_values([2], types)
1010+
body2 += encode_uint64(ROW_DONE_MARKER)
1011+
h2 = Header(size_words=len(body2) // 8, msg_type=7, schema=0)
1012+
1013+
# Normal message after the ROWS sequence
1014+
normal = ResultResponse(last_insert_id=5, rows_affected=3).encode()
1015+
1016+
decoder = MessageDecoder(is_request=False)
1017+
decoder.feed(h1.encode() + body1 + h2.encode() + body2 + normal)
1018+
1019+
# decode initial
1020+
initial = decoder.decode()
1021+
assert isinstance(initial, RowsResponse) and initial.has_more
1022+
1023+
# decode continuation
1024+
cont = decoder.decode_continuation(
1025+
column_names=initial.column_names,
1026+
column_count=len(initial.column_names),
1027+
)
1028+
assert isinstance(cont, RowsResponse) and not cont.has_more
1029+
1030+
# Now decode() should work again
1031+
result = decoder.decode()
1032+
assert isinstance(result, ResultResponse)
1033+
assert result.last_insert_id == 5
1034+
1035+
def test_reset_clears_continuation_expected(self) -> None:
1036+
"""reset() must clear the continuation-expected flag."""
1037+
from dqlitewire.constants import ROW_PART_MARKER, ValueType
1038+
from dqlitewire.messages.base import Header
1039+
from dqlitewire.messages.responses import RowsResponse
1040+
from dqlitewire.tuples import encode_row_header, encode_row_values
1041+
from dqlitewire.types import encode_text, encode_uint64
1042+
1043+
types = [ValueType.INTEGER]
1044+
body = encode_uint64(1)
1045+
body += encode_text("id")
1046+
body += encode_row_header(types)
1047+
body += encode_row_values([1], types)
1048+
body += encode_uint64(ROW_PART_MARKER)
1049+
header = Header(size_words=len(body) // 8, msg_type=7, schema=0)
1050+
1051+
decoder = MessageDecoder(is_request=False)
1052+
decoder.feed(header.encode() + body)
1053+
result = decoder.decode()
1054+
assert isinstance(result, RowsResponse) and result.has_more
1055+
1056+
# Reset should clear the flag
1057+
decoder.reset()
1058+
normal = ResultResponse(last_insert_id=0, rows_affected=0).encode()
1059+
decoder.feed(normal)
1060+
msg = decoder.decode()
1061+
assert isinstance(msg, ResultResponse)
1062+
1063+
def test_has_more_false_does_not_set_flag(self) -> None:
1064+
"""A RowsResponse with has_more=False should NOT set the flag."""
1065+
from dqlitewire.messages.responses import RowsResponse
1066+
1067+
decoder = MessageDecoder(is_request=False)
1068+
msg = RowsResponse(
1069+
column_names=["x"],
1070+
column_types=[1],
1071+
rows=[[1]],
1072+
has_more=False,
1073+
)
1074+
normal = ResultResponse(last_insert_id=0, rows_affected=0).encode()
1075+
decoder.feed(msg.encode() + normal)
1076+
1077+
result = decoder.decode()
1078+
assert isinstance(result, RowsResponse) and not result.has_more
1079+
1080+
# decode() should work fine — no continuation expected
1081+
result2 = decoder.decode()
1082+
assert isinstance(result2, ResultResponse)
1083+
1084+
9421085
class TestDecoderSkipMessage:
9431086
"""Test skip_message() and is_skipping on MessageDecoder."""
9441087

0 commit comments

Comments
 (0)