|
| 1 | +"""Regression test for the ``_read_continuation`` error wrap. |
| 2 | +
|
| 3 | +The initial-response decode path wraps ``_WireProtocolError`` into the |
| 4 | +client-layer ``ProtocolError`` (tested elsewhere). The continuation- |
| 5 | +frame path (``_read_continuation``) has the same wrap at protocol.py |
| 6 | +lines 383-384 but was not covered by a dedicated test. A regression |
| 7 | +would leak the wire exception out of the client layer, breaking |
| 8 | +``except ProtocolError`` boundaries for callers that assemble |
| 9 | +multi-frame result sets. |
| 10 | +""" |
| 11 | + |
| 12 | +from unittest.mock import AsyncMock, MagicMock |
| 13 | + |
| 14 | +import pytest |
| 15 | + |
| 16 | +from dqliteclient.exceptions import ProtocolError |
| 17 | +from dqliteclient.protocol import DqliteProtocol |
| 18 | +from dqlitewire.constants import ROW_PART_MARKER |
| 19 | +from dqlitewire.messages import RowsResponse |
| 20 | + |
| 21 | + |
| 22 | +@pytest.fixture |
| 23 | +def protocol() -> DqliteProtocol: |
| 24 | + reader = AsyncMock() |
| 25 | + writer = MagicMock() |
| 26 | + writer.drain = AsyncMock() |
| 27 | + writer.close = MagicMock() |
| 28 | + writer.wait_closed = AsyncMock() |
| 29 | + return DqliteProtocol(reader, writer, timeout=1.0) |
| 30 | + |
| 31 | + |
| 32 | +class TestReadContinuationWrapsWireError: |
| 33 | + async def test_malformed_continuation_bytes_raise_client_protocol_error( |
| 34 | + self, protocol: DqliteProtocol |
| 35 | + ) -> None: |
| 36 | + """Feed a valid initial ``RowsResponse(has_more=True)`` followed |
| 37 | + by garbage continuation bytes; the resulting wire-level |
| 38 | + ``ProtocolError`` must be wrapped into the client layer's |
| 39 | + ``ProtocolError`` so ``except dqliteclient.ProtocolError`` |
| 40 | + continues to work. |
| 41 | + """ |
| 42 | + import struct |
| 43 | + |
| 44 | + from dqlitewire.constants import ValueType |
| 45 | + from dqlitewire.types import encode_uint64 |
| 46 | + |
| 47 | + # Initial frame: a 1-column RowsResponse with one row, followed |
| 48 | + # by a PART marker to signal more rows in the next frame. |
| 49 | + initial = RowsResponse( |
| 50 | + column_names=["x"], |
| 51 | + column_types=[ValueType.INTEGER], |
| 52 | + rows=[(1,)], |
| 53 | + row_types=[(ValueType.INTEGER,)], |
| 54 | + has_more=False, # encoded below manually |
| 55 | + ).encode() |
| 56 | + # Swap the trailing DONE marker to PART so the decoder expects |
| 57 | + # a continuation frame. |
| 58 | + part_marker = encode_uint64(ROW_PART_MARKER) |
| 59 | + done_marker_prefix = initial[:-8] |
| 60 | + frame_with_part = done_marker_prefix + part_marker |
| 61 | + |
| 62 | + # A continuation frame starts with a header: size_words (uint32), |
| 63 | + # msg_type (uint8), schema (uint8), reserved (uint16), then body. |
| 64 | + # Feed 40 bytes of random garbage as a "continuation frame". |
| 65 | + garbage_header = struct.pack("<IBBH", 1, 0, 0, 0) |
| 66 | + garbage_body = b"\x00" * 8 |
| 67 | + garbage_frame = garbage_header + garbage_body |
| 68 | + |
| 69 | + # Reader returns the valid initial then the garbage in one shot. |
| 70 | + protocol._reader.read = AsyncMock( # type: ignore[attr-defined] |
| 71 | + side_effect=[frame_with_part + garbage_frame, b""] |
| 72 | + ) |
| 73 | + |
| 74 | + protocol._handshake_done = True |
| 75 | + # Kick off a query; any read past the PART marker should trip |
| 76 | + # the wire decoder and surface as client ProtocolError via |
| 77 | + # _read_continuation. |
| 78 | + with pytest.raises(ProtocolError): |
| 79 | + await protocol.query_sql(1, "SELECT 1") |
0 commit comments