@@ -939,6 +939,149 @@ def test_decode_continuation_raises_on_unexpected_type(self) -> None:
939939 decoder .decode_continuation (column_names = ["id" ], column_count = 1 )
940940
941941
942+ class TestDecoderContinuationExpected :
943+ """Regression tests for issue 058.
944+
945+ When ``decode()`` returns a ``RowsResponse`` with ``has_more=True``,
946+ the decoder enters a "continuation expected" state. Calling ``decode()``
947+ again before draining all continuations via ``decode_continuation()``
948+ is a protocol error — the next frame in the buffer is a continuation
949+ (no column header prefix), so ``decode()`` would misparse it. The
950+ ``_continuation_expected`` flag makes this misuse fail loudly with
951+ ``ProtocolError`` instead of producing silent stream desynchronization.
952+ """
953+
954+ def test_decode_raises_when_continuation_expected (self ) -> None :
955+ """decode() must refuse while a ROWS continuation is in progress."""
956+ from dqlitewire .constants import ROW_PART_MARKER , ValueType
957+ from dqlitewire .exceptions import ProtocolError
958+ from dqlitewire .messages .base import Header
959+ from dqlitewire .messages .responses import RowsResponse
960+ from dqlitewire .tuples import encode_row_header , encode_row_values
961+ from dqlitewire .types import encode_text , encode_uint64
962+
963+ types = [ValueType .INTEGER ]
964+
965+ # Build a RowsResponse with has_more=True
966+ body = encode_uint64 (1 ) # column_count
967+ body += encode_text ("id" )
968+ body += encode_row_header (types )
969+ body += encode_row_values ([1 ], types )
970+ body += encode_uint64 (ROW_PART_MARKER )
971+ header = Header (size_words = len (body ) // 8 , msg_type = 7 , schema = 0 )
972+ msg_bytes = header .encode () + body
973+
974+ # Also feed a second (standalone) message that decode() would try to read
975+ second = ResultResponse (last_insert_id = 0 , rows_affected = 0 ).encode ()
976+
977+ decoder = MessageDecoder (is_request = False )
978+ decoder .feed (msg_bytes + second )
979+
980+ # First decode() returns the initial RowsResponse with has_more=True
981+ result = decoder .decode ()
982+ assert isinstance (result , RowsResponse )
983+ assert result .has_more is True
984+
985+ # Second decode() must raise — we're in "continuation expected" state
986+ with pytest .raises (ProtocolError , match = "continuation" ):
987+ decoder .decode ()
988+
989+ def test_decode_continuation_clears_flag (self ) -> None :
990+ """After draining all continuations (has_more=False), decode() works again."""
991+ from dqlitewire .constants import ROW_DONE_MARKER , ROW_PART_MARKER , ValueType
992+ from dqlitewire .messages .base import Header
993+ from dqlitewire .messages .responses import RowsResponse
994+ from dqlitewire .tuples import encode_row_header , encode_row_values
995+ from dqlitewire .types import encode_text , encode_uint64
996+
997+ types = [ValueType .INTEGER ]
998+
999+ # Initial frame (has_more=True)
1000+ body1 = encode_uint64 (1 )
1001+ body1 += encode_text ("id" )
1002+ body1 += encode_row_header (types )
1003+ body1 += encode_row_values ([1 ], types )
1004+ body1 += encode_uint64 (ROW_PART_MARKER )
1005+ h1 = Header (size_words = len (body1 ) // 8 , msg_type = 7 , schema = 0 )
1006+
1007+ # Continuation frame (has_more=False)
1008+ body2 = encode_row_header (types )
1009+ body2 += encode_row_values ([2 ], types )
1010+ body2 += encode_uint64 (ROW_DONE_MARKER )
1011+ h2 = Header (size_words = len (body2 ) // 8 , msg_type = 7 , schema = 0 )
1012+
1013+ # Normal message after the ROWS sequence
1014+ normal = ResultResponse (last_insert_id = 5 , rows_affected = 3 ).encode ()
1015+
1016+ decoder = MessageDecoder (is_request = False )
1017+ decoder .feed (h1 .encode () + body1 + h2 .encode () + body2 + normal )
1018+
1019+ # decode initial
1020+ initial = decoder .decode ()
1021+ assert isinstance (initial , RowsResponse ) and initial .has_more
1022+
1023+ # decode continuation
1024+ cont = decoder .decode_continuation (
1025+ column_names = initial .column_names ,
1026+ column_count = len (initial .column_names ),
1027+ )
1028+ assert isinstance (cont , RowsResponse ) and not cont .has_more
1029+
1030+ # Now decode() should work again
1031+ result = decoder .decode ()
1032+ assert isinstance (result , ResultResponse )
1033+ assert result .last_insert_id == 5
1034+
1035+ def test_reset_clears_continuation_expected (self ) -> None :
1036+ """reset() must clear the continuation-expected flag."""
1037+ from dqlitewire .constants import ROW_PART_MARKER , ValueType
1038+ from dqlitewire .messages .base import Header
1039+ from dqlitewire .messages .responses import RowsResponse
1040+ from dqlitewire .tuples import encode_row_header , encode_row_values
1041+ from dqlitewire .types import encode_text , encode_uint64
1042+
1043+ types = [ValueType .INTEGER ]
1044+ body = encode_uint64 (1 )
1045+ body += encode_text ("id" )
1046+ body += encode_row_header (types )
1047+ body += encode_row_values ([1 ], types )
1048+ body += encode_uint64 (ROW_PART_MARKER )
1049+ header = Header (size_words = len (body ) // 8 , msg_type = 7 , schema = 0 )
1050+
1051+ decoder = MessageDecoder (is_request = False )
1052+ decoder .feed (header .encode () + body )
1053+ result = decoder .decode ()
1054+ assert isinstance (result , RowsResponse ) and result .has_more
1055+
1056+ # Reset should clear the flag
1057+ decoder .reset ()
1058+ normal = ResultResponse (last_insert_id = 0 , rows_affected = 0 ).encode ()
1059+ decoder .feed (normal )
1060+ msg = decoder .decode ()
1061+ assert isinstance (msg , ResultResponse )
1062+
1063+ def test_has_more_false_does_not_set_flag (self ) -> None :
1064+ """A RowsResponse with has_more=False should NOT set the flag."""
1065+ from dqlitewire .messages .responses import RowsResponse
1066+
1067+ decoder = MessageDecoder (is_request = False )
1068+ msg = RowsResponse (
1069+ column_names = ["x" ],
1070+ column_types = [1 ],
1071+ rows = [[1 ]],
1072+ has_more = False ,
1073+ )
1074+ normal = ResultResponse (last_insert_id = 0 , rows_affected = 0 ).encode ()
1075+ decoder .feed (msg .encode () + normal )
1076+
1077+ result = decoder .decode ()
1078+ assert isinstance (result , RowsResponse ) and not result .has_more
1079+
1080+ # decode() should work fine — no continuation expected
1081+ result2 = decoder .decode ()
1082+ assert isinstance (result2 , ResultResponse )
1083+
1084+
9421085class TestDecoderSkipMessage :
9431086 """Test skip_message() and is_skipping on MessageDecoder."""
9441087
0 commit comments