Skip to content

Commit 2b29b42

Browse files
Add decode_continuation() to MessageDecoder for multi-part ROWS responses
When a server streams large result sets with PART markers, continuation messages have the same ROWS type code but no column_count prefix. Calling decode() on these would misinterpret row data as a column count. The new decode_continuation() method reads the next message and delegates to RowsResponse.decode_rows_continuation() with the caller-provided column metadata from the initial response. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 19ebbcb commit 2b29b42

2 files changed

Lines changed: 92 additions & 0 deletions

File tree

src/dqlitewire/codec.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,36 @@ def is_skipping(self) -> bool:
178178
"""True if still discarding bytes from an oversized message."""
179179
return self._buffer.is_skipping
180180

181+
def decode_continuation(
182+
self,
183+
column_names: list[str],
184+
column_count: int,
185+
max_rows: int = RowsResponse.DEFAULT_MAX_ROWS,
186+
) -> RowsResponse | None:
187+
"""Decode a ROWS continuation message from the buffer.
188+
189+
After receiving a RowsResponse with has_more=True, the server sends
190+
additional ROWS messages containing only row data (no column_count or
191+
column_names prefix). Use this method instead of decode() for those
192+
continuation messages.
193+
194+
Args:
195+
column_names: Column names from the initial RowsResponse.
196+
column_count: Number of columns from the initial RowsResponse.
197+
max_rows: Maximum rows to decode per message.
198+
199+
Returns None if no complete message is available.
200+
"""
201+
data = self._buffer.read_message()
202+
if data is None:
203+
return None
204+
205+
header = Header.decode(data[:HEADER_SIZE])
206+
body = data[HEADER_SIZE : HEADER_SIZE + header.size_words * 8]
207+
return RowsResponse.decode_rows_continuation(
208+
body, column_names, column_count, max_rows=max_rows
209+
)
210+
181211
def decode(self) -> Message | None:
182212
"""Decode the next message from the buffer.
183213

tests/test_codec.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,68 @@ def test_stmt_response_v0_with_trailing_data_not_detected_as_v1(self) -> None:
553553
assert decoded.tail_offset is None
554554

555555

556+
class TestDecoderContinuation:
557+
"""Test decode_continuation() for multi-part ROWS responses."""
558+
559+
def test_decode_continuation_exists(self) -> None:
560+
"""MessageDecoder should have a decode_continuation method."""
561+
decoder = MessageDecoder(is_request=False)
562+
assert hasattr(decoder, "decode_continuation")
563+
564+
def test_decode_continuation_roundtrip(self) -> None:
565+
"""Multi-part ROWS: initial decode() + continuation decode_continuation()."""
566+
from dqlitewire.constants import ROW_DONE_MARKER, ROW_PART_MARKER, ValueType
567+
from dqlitewire.messages.base import Header
568+
from dqlitewire.messages.responses import RowsResponse
569+
from dqlitewire.tuples import encode_row_header, encode_row_values
570+
from dqlitewire.types import encode_text, encode_uint64
571+
572+
types = [ValueType.INTEGER, ValueType.TEXT]
573+
574+
# Build initial ROWS message with PART marker
575+
body1 = encode_uint64(2) # column_count
576+
body1 += encode_text("id") + encode_text("name")
577+
body1 += encode_row_header(types)
578+
body1 += encode_row_values([1, "alice"], types)
579+
body1 += encode_uint64(ROW_PART_MARKER)
580+
header1 = Header(size_words=len(body1) // 8, msg_type=7, schema=0)
581+
msg1_bytes = header1.encode() + body1
582+
583+
# Build continuation ROWS message with DONE marker
584+
body2 = encode_row_header(types)
585+
body2 += encode_row_values([2, "bob"], types)
586+
body2 += encode_uint64(ROW_DONE_MARKER)
587+
header2 = Header(size_words=len(body2) // 8, msg_type=7, schema=0)
588+
msg2_bytes = header2.encode() + body2
589+
590+
# Feed both messages
591+
decoder = MessageDecoder(is_request=False)
592+
decoder.feed(msg1_bytes + msg2_bytes)
593+
594+
# Decode initial response
595+
initial = decoder.decode()
596+
assert isinstance(initial, RowsResponse)
597+
assert initial.has_more is True
598+
assert len(initial.rows) == 1
599+
assert initial.rows[0] == [1, "alice"]
600+
601+
# Decode continuation
602+
continuation = decoder.decode_continuation(
603+
column_names=initial.column_names,
604+
column_count=len(initial.column_names),
605+
)
606+
assert isinstance(continuation, RowsResponse)
607+
assert continuation.has_more is False
608+
assert len(continuation.rows) == 1
609+
assert continuation.rows[0] == [2, "bob"]
610+
611+
def test_decode_continuation_returns_none_when_no_data(self) -> None:
612+
"""decode_continuation should return None when no message is available."""
613+
decoder = MessageDecoder(is_request=False)
614+
result = decoder.decode_continuation(column_names=["x"], column_count=1)
615+
assert result is None
616+
617+
556618
class TestDecoderSkipMessage:
557619
"""Test skip_message() and is_skipping on MessageDecoder."""
558620

0 commit comments

Comments
 (0)