Skip to content

Commit d48a734

Browse files
fix: decode continuations using decode_body (matching C server wire format)
The C dqlite server always writes column_count + column_names in every ROWS message, including continuations sent after a PART marker. Python's decode_rows_continuation assumed that continuations omit the column header and jumped straight to the row data, causing parse errors against real server output. Fix decode_continuation() to call RowsResponse.decode_body() instead of decode_rows_continuation(), matching the actual wire format. Remove the now-unnecessary column_names and column_count parameters, since that information is carried in the frame itself; the max_rows parameter is removed as well, making decode_continuation() parameterless. Update all callers and tests to use the new parameterless API with continuation frames that include the column header prefix. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 82fee8a commit d48a734

3 files changed

Lines changed: 70 additions & 56 deletions

File tree

src/dqlitewire/codec.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -244,27 +244,21 @@ def is_skipping(self) -> bool:
244244
"""True if still discarding bytes from an oversized message."""
245245
return self._buffer.is_skipping
246246

247-
def decode_continuation(
248-
self,
249-
column_names: list[str],
250-
column_count: int,
251-
max_rows: int = RowsResponse.DEFAULT_MAX_ROWS,
252-
) -> RowsResponse | None:
247+
def decode_continuation(self) -> RowsResponse | None:
253248
"""Decode a ROWS continuation message from the buffer.
254249
255-
After receiving a RowsResponse with has_more=True, the server sends
256-
additional ROWS messages containing only row data (no column_count or
257-
column_names prefix). Use this method instead of decode() for those
258-
continuation messages.
259-
260-
Args:
261-
column_names: Column names from the initial RowsResponse.
262-
column_count: Number of columns from the initial RowsResponse.
263-
max_rows: Maximum rows to decode per message.
250+
After receiving a RowsResponse with has_more=True, call this
251+
method to decode each subsequent ROWS frame. The C dqlite server
252+
sends continuation frames with the same body layout as the
253+
initial frame (column_count + column_names + rows + marker),
254+
so this method uses ``RowsResponse.decode_body`` — the same
255+
decoder used for the initial frame.
264256
265257
Returns None if no complete message is available.
266258
Raises ProtocolError if no continuation is in progress
267259
(use ``decode()`` for the initial message).
260+
Raises ProtocolError if the server sends a FailureResponse
261+
instead of a ROWS continuation (e.g., mid-stream I/O error).
268262
"""
269263
self._buffer._check_poisoned()
270264
if not self._continuation_expected:
@@ -276,14 +270,6 @@ def decode_continuation(
276270
if data is None:
277271
return None
278272

279-
# Bytes have been consumed — ANY failure here leaves the
280-
# buffer at an unknown offset and must poison the decoder.
281-
# Catch BaseException so that signal-delivered KeyboardInterrupt
282-
# (issue 045) and unexpected BaseException subclasses also
283-
# poison before propagating. `decode_body` implementations can
284-
# raise struct.error, ValueError, UnicodeDecodeError,
285-
# IndexError, etc., and all of them mean the stream is
286-
# desynchronized.
287273
try:
288274
header = Header.decode(data[:HEADER_SIZE])
289275
body = data[HEADER_SIZE : HEADER_SIZE + header.size_words * 8]
@@ -299,16 +285,11 @@ def decode_continuation(
299285
f"got type {header.msg_type}"
300286
)
301287

302-
result = RowsResponse.decode_rows_continuation(
303-
body, column_names, column_count, max_rows=max_rows
304-
)
288+
result = RowsResponse.decode_body(body, schema=header.schema)
305289
if not result.has_more:
306290
self._continuation_expected = False
307291
return result
308292
except BaseException as e:
309-
# poison() stores Exception | None; wrap non-Exception
310-
# BaseException subclasses so the poison cause is still a
311-
# real Exception we can inspect.
312293
self._buffer.poison(
313294
e
314295
if isinstance(e, Exception)

tests/test_codec.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,10 @@ def test_decode_continuation_roundtrip(self) -> None:
875875
msg1_bytes = header1.encode() + body1
876876

877877
# Build continuation ROWS message with DONE marker
878-
body2 = encode_row_header(types)
878+
# (C server always includes column_count + column_names)
879+
body2 = encode_uint64(2)
880+
body2 += encode_text("id") + encode_text("name")
881+
body2 += encode_row_header(types)
879882
body2 += encode_row_values([2, "bob"], types)
880883
body2 += encode_uint64(ROW_DONE_MARKER)
881884
header2 = Header(size_words=len(body2) // 8, msg_type=7, schema=0)
@@ -893,20 +896,59 @@ def test_decode_continuation_roundtrip(self) -> None:
893896
assert initial.rows[0] == [1, "alice"]
894897

895898
# Decode continuation
896-
continuation = decoder.decode_continuation(
897-
column_names=initial.column_names,
898-
column_count=len(initial.column_names),
899-
)
899+
continuation = decoder.decode_continuation()
900900
assert isinstance(continuation, RowsResponse)
901901
assert continuation.has_more is False
902902
assert len(continuation.rows) == 1
903903
assert continuation.rows[0] == [2, "bob"]
904904

905+
def test_decode_continuation_with_column_header(self) -> None:
906+
"""Continuation frames from the C server include column_count +
907+
column_names (same layout as the initial frame). Verify that
908+
decode_continuation handles this correctly.
909+
"""
910+
from dqlitewire.constants import ROW_DONE_MARKER, ROW_PART_MARKER, ValueType
911+
from dqlitewire.messages.base import Header
912+
from dqlitewire.messages.responses import RowsResponse
913+
from dqlitewire.tuples import encode_row_header, encode_row_values
914+
from dqlitewire.types import encode_text, encode_uint64
915+
916+
types = [ValueType.INTEGER, ValueType.TEXT]
917+
918+
# Initial ROWS message (PART marker)
919+
body1 = encode_uint64(2)
920+
body1 += encode_text("id") + encode_text("name")
921+
body1 += encode_row_header(types)
922+
body1 += encode_row_values([1, "alice"], types)
923+
body1 += encode_uint64(ROW_PART_MARKER)
924+
h1 = Header(size_words=len(body1) // 8, msg_type=7, schema=0)
925+
926+
# Continuation WITH column header (matching C server output)
927+
body2 = encode_uint64(2)
928+
body2 += encode_text("id") + encode_text("name")
929+
body2 += encode_row_header(types)
930+
body2 += encode_row_values([2, "bob"], types)
931+
body2 += encode_uint64(ROW_DONE_MARKER)
932+
h2 = Header(size_words=len(body2) // 8, msg_type=7, schema=0)
933+
934+
decoder = MessageDecoder(is_request=False)
935+
decoder.feed(h1.encode() + body1 + h2.encode() + body2)
936+
937+
initial = decoder.decode()
938+
assert isinstance(initial, RowsResponse)
939+
assert initial.has_more is True
940+
assert initial.rows[0] == [1, "alice"]
941+
942+
cont = decoder.decode_continuation()
943+
assert isinstance(cont, RowsResponse)
944+
assert cont.has_more is False
945+
assert cont.rows[0] == [2, "bob"]
946+
905947
def test_decode_continuation_returns_none_when_no_data(self) -> None:
906948
"""decode_continuation should return None when no message is available."""
907949
decoder = MessageDecoder(is_request=False)
908950
decoder._continuation_expected = True
909-
result = decoder.decode_continuation(column_names=["x"], column_count=1)
951+
result = decoder.decode_continuation()
910952
assert result is None
911953

912954
def test_decode_continuation_raises_on_failure_response(self) -> None:
@@ -922,7 +964,7 @@ def test_decode_continuation_raises_on_failure_response(self) -> None:
922964
decoder.feed(failure_bytes)
923965

924966
with pytest.raises(ProtocolError, match="disk I/O error") as exc_info:
925-
decoder.decode_continuation(column_names=["id"], column_count=1)
967+
decoder.decode_continuation()
926968
# Must be a ProtocolError, NOT a DecodeError (which would mean the
927969
# failure body was misinterpreted as row data).
928970
assert type(exc_info.value) is ProtocolError
@@ -939,7 +981,7 @@ def test_decode_continuation_raises_on_unexpected_type(self) -> None:
939981
decoder.feed(result_bytes)
940982

941983
with pytest.raises(ProtocolError, match="Expected ROWS continuation"):
942-
decoder.decode_continuation(column_names=["id"], column_count=1)
984+
decoder.decode_continuation()
943985

944986

945987
class TestDecoderContinuationExpected:
@@ -1007,8 +1049,10 @@ def test_decode_continuation_clears_flag(self) -> None:
10071049
body1 += encode_uint64(ROW_PART_MARKER)
10081050
h1 = Header(size_words=len(body1) // 8, msg_type=7, schema=0)
10091051

1010-
# Continuation frame (has_more=False)
1011-
body2 = encode_row_header(types)
1052+
# Continuation frame (has_more=False) — includes column header
1053+
body2 = encode_uint64(1)
1054+
body2 += encode_text("id")
1055+
body2 += encode_row_header(types)
10121056
body2 += encode_row_values([2], types)
10131057
body2 += encode_uint64(ROW_DONE_MARKER)
10141058
h2 = Header(size_words=len(body2) // 8, msg_type=7, schema=0)
@@ -1024,10 +1068,7 @@ def test_decode_continuation_clears_flag(self) -> None:
10241068
assert isinstance(initial, RowsResponse) and initial.has_more
10251069

10261070
# decode continuation
1027-
cont = decoder.decode_continuation(
1028-
column_names=initial.column_names,
1029-
column_count=len(initial.column_names),
1030-
)
1071+
cont = decoder.decode_continuation()
10311072
assert isinstance(cont, RowsResponse) and not cont.has_more
10321073

10331074
# Now decode() should work again
@@ -1107,7 +1148,7 @@ def test_decode_continuation_raises_when_not_expected(self) -> None:
11071148
decoder.feed(result.encode())
11081149

11091150
with pytest.raises(ProtocolError, match="no ROWS continuation"):
1110-
decoder.decode_continuation(column_names=["x"], column_count=1)
1151+
decoder.decode_continuation()
11111152

11121153
# The message must NOT have been consumed
11131154
assert decoder.has_message(), (
@@ -1339,7 +1380,7 @@ def test_decode_continuation_poisons_on_wrong_type(self) -> None:
13391380
decoder.feed(empty)
13401381

13411382
with pytest.raises(ProtocolError, match="Expected ROWS continuation"):
1342-
decoder.decode_continuation(column_names=["x"], column_count=1)
1383+
decoder.decode_continuation()
13431384
assert decoder.is_poisoned is True
13441385

13451386
# Subsequent decode() also fails poisoned.

tests/test_codec_signal_safety.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,6 @@ def test_keyboard_interrupt_in_decode_continuation_poisons(self) -> None:
126126
"""Same hazard as decode(): decode_continuation() also needs
127127
BaseException handling.
128128
"""
129-
from dqlitewire.messages.responses import RowsResponse
130-
131129
# Build a minimal ROWS continuation body. We don't actually
132130
# reach the parser — the tracer injects the interrupt before
133131
# any real parsing runs — so the frame just needs to carry
@@ -137,6 +135,7 @@ def test_keyboard_interrupt_in_decode_continuation_poisons(self) -> None:
137135
frame = header + body
138136

139137
dec = MessageDecoder(is_request=False)
138+
dec._continuation_expected = True
140139
dec.feed(frame + frame)
141140

142141
# Inject inside decode_continuation itself (the wrapper
@@ -146,11 +145,7 @@ def test_keyboard_interrupt_in_decode_continuation_poisons(self) -> None:
146145
sys.settrace(tracer)
147146
try:
148147
with contextlib.suppress(KeyboardInterrupt):
149-
dec.decode_continuation(
150-
column_names=["a"],
151-
column_count=1,
152-
max_rows=RowsResponse.DEFAULT_MAX_ROWS,
153-
)
148+
dec.decode_continuation()
154149
finally:
155150
sys.settrace(None)
156151

@@ -162,10 +157,7 @@ def test_keyboard_interrupt_in_decode_continuation_poisons(self) -> None:
162157
# still at a valid offset.
163158
if dec.is_poisoned:
164159
with pytest.raises(ProtocolError, match="poisoned"):
165-
dec.decode_continuation(
166-
column_names=["a"],
167-
column_count=1,
168-
)
160+
dec.decode_continuation()
169161
else:
170162
# Buffer offset must still be sensible.
171163
assert dec._buffer.available() >= 0

0 commit comments

Comments
 (0)