Skip to content

Commit 0b87fea

Browse files
Cover _read_continuation's wire-decode error wrap
The initial-response path has tests for the _WireProtocolError -> client ProtocolError wrap; the continuation- frame path (RowsResponse has_more=True, then read the next frame) has the same wrap but was not covered. Add a test that feeds a valid initial frame followed by garbage continuation bytes and asserts the resulting error is the client-layer ProtocolError, so callers catching dqliteclient.ProtocolError across multi-frame streaming reads continue to see the expected taxonomy. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 038fb58 commit 0b87fea

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Regression test for the ``_read_continuation`` error wrap.
2+
3+
The initial-response decode path wraps ``_WireProtocolError`` into the
4+
client-layer ``ProtocolError`` (tested elsewhere). The continuation-
5+
frame path (``_read_continuation``) has the same wrap at protocol.py
6+
lines 383-384 but was not covered by a dedicated test. A regression
7+
would leak the wire exception out of the client layer, breaking
8+
``except ProtocolError`` boundaries for callers that assemble
9+
multi-frame result sets.
10+
"""
11+
12+
from unittest.mock import AsyncMock, MagicMock
13+
14+
import pytest
15+
16+
from dqliteclient.exceptions import ProtocolError
17+
from dqliteclient.protocol import DqliteProtocol
18+
from dqlitewire.constants import ROW_PART_MARKER
19+
from dqlitewire.messages import RowsResponse
20+
21+
22+
@pytest.fixture
23+
def protocol() -> DqliteProtocol:
24+
reader = AsyncMock()
25+
writer = MagicMock()
26+
writer.drain = AsyncMock()
27+
writer.close = MagicMock()
28+
writer.wait_closed = AsyncMock()
29+
return DqliteProtocol(reader, writer, timeout=1.0)
30+
31+
32+
class TestReadContinuationWrapsWireError:
33+
async def test_malformed_continuation_bytes_raise_client_protocol_error(
34+
self, protocol: DqliteProtocol
35+
) -> None:
36+
"""Feed a valid initial ``RowsResponse(has_more=True)`` followed
37+
by garbage continuation bytes; the resulting wire-level
38+
``ProtocolError`` must be wrapped into the client layer's
39+
``ProtocolError`` so ``except dqliteclient.ProtocolError``
40+
continues to work.
41+
"""
42+
import struct
43+
44+
from dqlitewire.constants import ValueType
45+
from dqlitewire.types import encode_uint64
46+
47+
# Initial frame: a 1-column RowsResponse with one row, followed
48+
# by a PART marker to signal more rows in the next frame.
49+
initial = RowsResponse(
50+
column_names=["x"],
51+
column_types=[ValueType.INTEGER],
52+
rows=[(1,)],
53+
row_types=[(ValueType.INTEGER,)],
54+
has_more=False, # encoded below manually
55+
).encode()
56+
# Swap the trailing DONE marker to PART so the decoder expects
57+
# a continuation frame.
58+
part_marker = encode_uint64(ROW_PART_MARKER)
59+
done_marker_prefix = initial[:-8]
60+
frame_with_part = done_marker_prefix + part_marker
61+
62+
# A continuation frame starts with a header: size_words (uint32),
63+
# msg_type (uint8), schema (uint8), reserved (uint16), then body.
64+
# Feed 40 bytes of random garbage as a "continuation frame".
65+
garbage_header = struct.pack("<IBBH", 1, 0, 0, 0)
66+
garbage_body = b"\x00" * 8
67+
garbage_frame = garbage_header + garbage_body
68+
69+
# Reader returns the valid initial then the garbage in one shot.
70+
protocol._reader.read = AsyncMock( # type: ignore[attr-defined]
71+
side_effect=[frame_with_part + garbage_frame, b""]
72+
)
73+
74+
protocol._handshake_done = True
75+
# Kick off a query; any read past the PART marker should trip
76+
# the wire decoder and surface as client ProtocolError via
77+
# _read_continuation.
78+
with pytest.raises(ProtocolError):
79+
await protocol.query_sql(1, "SELECT 1")

0 commit comments

Comments
 (0)