Skip to content

Commit 7068a34

Browse files
fix: plumb max_rows through MessageDecoder to RowsResponse
RowsResponse.decode_body accepts a max_rows cap, but MessageDecoder never forwarded it at its two call sites (decode_bytes and decode_continuation). Callers tightening max_message_size had no paired knob for in-memory row count. Add a max_rows constructor parameter, default to RowsResponse.DEFAULT_MAX_ROWS, validate max_rows >= 1, and forward at both call sites. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9d17383 commit 7068a34

2 files changed

Lines changed: 76 additions & 1 deletion

File tree

src/dqlitewire/codec.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def __init__(
181181
is_request: bool = False,
182182
version: int = PROTOCOL_VERSION,
183183
max_message_size: int = ReadBuffer.DEFAULT_MAX_MESSAGE_SIZE,
184+
max_rows: int = RowsResponse.DEFAULT_MAX_ROWS,
184185
) -> None:
185186
"""Initialize decoder.
186187
@@ -194,12 +195,18 @@ def __init__(
194195
max_message_size: Maximum allowed message size in bytes.
195196
Defaults to 64 MiB. Messages exceeding this limit are
196197
rejected with DecodeError.
198+
max_rows: Maximum number of rows permitted in a single
199+
``RowsResponse`` frame (including continuation frames).
200+
Defaults to ``RowsResponse.DEFAULT_MAX_ROWS``. Exceeding
201+
this limit raises ``DecodeError``.
197202
"""
198203
if not is_request and version not in _SUPPORTED_VERSIONS:
199204
raise ProtocolError(
200205
f"Unsupported protocol version: {version:#x}. "
201206
f"Supported: {', '.join(f'{v:#x}' for v in sorted(_SUPPORTED_VERSIONS))}"
202207
)
208+
if max_rows < 1:
209+
raise ValueError(f"max_rows must be >= 1, got {max_rows}")
203210
self._buffer = ReadBuffer(max_message_size=max_message_size)
204211
self._is_request = is_request
205212
self._type_map = REQUEST_TYPES if is_request else RESPONSE_TYPES
@@ -211,6 +218,7 @@ def __init__(
211218
# don't receive an inbound handshake, so they skip this check.
212219
self._handshake_done = not is_request
213220
self._continuation_expected = False
221+
self._max_rows = max_rows
214222

215223
@property
216224
def version(self) -> int | None:
@@ -334,7 +342,7 @@ def decode_continuation(self) -> RowsResponse | None:
334342
f"got type {header.msg_type}"
335343
)
336344

337-
result = RowsResponse.decode_body(body, schema=header.schema)
345+
result = RowsResponse.decode_body(body, schema=header.schema, max_rows=self._max_rows)
338346
if not result.has_more:
339347
self._continuation_expected = False
340348
return result
@@ -449,6 +457,11 @@ def decode_bytes(self, data: bytes) -> Message:
449457
):
450458
return LeaderResponse.decode_body_legacy(body)
451459

460+
# RowsResponse takes an extra ``max_rows`` cap; other classes
461+
# share the generic (body, schema) signature.
462+
if msg_class is RowsResponse:
463+
return RowsResponse.decode_body(body, schema=header.schema, max_rows=self._max_rows)
464+
452465
return msg_class.decode_body(body, schema=header.schema)
453466

454467
def decode_handshake(self) -> int | None:

tests/test_codec.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,68 @@ def test_decoder_default_max_message_size(self) -> None:
115115
decoder = MessageDecoder()
116116
assert decoder._buffer._max_message_size == ReadBuffer.DEFAULT_MAX_MESSAGE_SIZE
117117

118+
def test_decoder_custom_max_rows(self) -> None:
119+
"""231: max_rows should be forwarded to RowsResponse.decode_body.
120+
121+
The knob exists on RowsResponse.decode_body, but MessageDecoder
122+
previously omitted it at the call site, leaving the 1M default
123+
always in force regardless of caller configuration.
124+
"""
125+
from dqlitewire.constants import ValueType
126+
from dqlitewire.exceptions import DecodeError
127+
from dqlitewire.messages.responses import RowsResponse
128+
129+
msg = RowsResponse(
130+
column_names=["a"],
131+
column_types=[ValueType.INTEGER],
132+
row_types=[[ValueType.INTEGER]] * 3,
133+
rows=[[1], [2], [3]],
134+
)
135+
encoded = msg.encode()
136+
137+
decoder = MessageDecoder(max_rows=2)
138+
decoder.feed(encoded)
139+
with pytest.raises(DecodeError, match="reached maximum 2"):
140+
decoder.decode()
141+
142+
def test_decoder_default_max_rows(self) -> None:
143+
"""231: default max_rows should match RowsResponse.DEFAULT_MAX_ROWS."""
144+
from dqlitewire.messages.responses import RowsResponse
145+
146+
decoder = MessageDecoder()
147+
assert decoder._max_rows == RowsResponse.DEFAULT_MAX_ROWS
148+
149+
def test_decoder_rejects_zero_max_rows(self) -> None:
150+
"""231: max_rows < 1 should be rejected at construction time."""
151+
with pytest.raises(ValueError, match="max_rows must be >= 1"):
152+
MessageDecoder(max_rows=0)
153+
154+
def test_decoder_continuation_honors_max_rows(self) -> None:
155+
"""231: decode_continuation should also honor max_rows.
156+
157+
A multi-frame ROWS stream can exceed max_rows across frames; the
158+
per-frame cap still applies to each continuation response.
159+
"""
160+
from dqlitewire.constants import ValueType
161+
from dqlitewire.exceptions import DecodeError
162+
from dqlitewire.messages.responses import RowsResponse
163+
164+
cont = RowsResponse(
165+
column_names=["a"],
166+
column_types=[ValueType.INTEGER],
167+
row_types=[[ValueType.INTEGER]] * 3,
168+
rows=[[1], [2], [3]],
169+
has_more=False,
170+
)
171+
cont_bytes = cont.encode()
172+
173+
decoder = MessageDecoder(max_rows=2)
174+
# Mark decoder as mid-continuation so decode_continuation() runs.
175+
decoder._continuation_expected = True
176+
decoder.feed(cont_bytes)
177+
with pytest.raises(DecodeError, match="reached maximum 2"):
178+
decoder.decode_continuation()
179+
118180
def test_decode_handshake(self) -> None:
119181
decoder = MessageDecoder(is_request=True)
120182
decoder.feed(PROTOCOL_VERSION.to_bytes(8, "little"))

0 commit comments

Comments
 (0)