|
14 | 14 | import pytest |
15 | 15 |
|
16 | 16 | from dqlitewire.exceptions import DecodeError |
17 | | -from dqlitewire.messages.responses import DbResponse, EmptyResponse, ResultResponse |
| 17 | +from dqlitewire.messages.responses import DbResponse, EmptyResponse, ResultResponse, StmtResponse |
18 | 18 |
|
19 | 19 |
|
20 | 20 | class TestEmptyResponseStrictLength: |
@@ -78,3 +78,61 @@ def test_exact_length_accepted(self) -> None: |
78 | 78 | def test_trailing_bytes_rejected(self) -> None: |
79 | 79 | with pytest.raises(DecodeError, match="ResultResponse body must be exactly 16 bytes"): |
80 | 80 | ResultResponse.decode_body(b"\x00" * 17) |
| 81 | + |
| 82 | + |
| 83 | +class TestStmtResponseStrictLength: |
| 84 | + """StmtResponse carries db_id + stmt_id + num_params (+ optional |
| 85 | + tail_offset for schema>=1). Body is exactly 16 bytes (schema 0) or |
| 86 | + 24 bytes (schema 1). Trailing bytes had been silently accepted; |
| 87 | + sibling responses reject them. |
| 88 | + """ |
| 89 | + |
| 90 | + @staticmethod |
| 91 | + def _body_schema0(db_id: int = 1, stmt_id: int = 42, num_params: int = 3) -> bytes: |
| 92 | + return ( |
| 93 | + db_id.to_bytes(4, "little") |
| 94 | + + stmt_id.to_bytes(4, "little") |
| 95 | + + num_params.to_bytes(8, "little") |
| 96 | + ) |
| 97 | + |
| 98 | + @staticmethod |
| 99 | + def _body_schema1( |
| 100 | + db_id: int = 1, stmt_id: int = 42, num_params: int = 3, tail_offset: int = 0 |
| 101 | + ) -> bytes: |
| 102 | + return TestStmtResponseStrictLength._body_schema0( |
| 103 | + db_id, stmt_id, num_params |
| 104 | + ) + tail_offset.to_bytes(8, "little") |
| 105 | + |
| 106 | + def test_short_body_rejected_schema0(self) -> None: |
| 107 | + with pytest.raises(DecodeError, match=r"StmtResponse schema=0 body must be exactly 16"): |
| 108 | + StmtResponse.decode_body(b"\x00" * 15, schema=0) |
| 109 | + |
| 110 | + def test_exact_length_accepted_schema0(self) -> None: |
| 111 | + msg = StmtResponse.decode_body(self._body_schema0(), schema=0) |
| 112 | + assert msg.db_id == 1 |
| 113 | + assert msg.stmt_id == 42 |
| 114 | + assert msg.num_params == 3 |
| 115 | + assert msg.tail_offset is None |
| 116 | + |
| 117 | + def test_trailing_bytes_rejected_schema0(self) -> None: |
| 118 | + """Previous decoder silently accepted extra bytes; must now |
| 119 | + raise so a conforming StmtResponse round-trips exactly.""" |
| 120 | + body = self._body_schema0() + b"\x01" |
| 121 | + with pytest.raises(DecodeError, match=r"StmtResponse schema=0 body must be exactly 16"): |
| 122 | + StmtResponse.decode_body(body, schema=0) |
| 123 | + |
| 124 | + def test_short_body_rejected_schema1(self) -> None: |
| 125 | + with pytest.raises(DecodeError, match=r"StmtResponse schema=1 body must be exactly 24"): |
| 126 | + StmtResponse.decode_body(b"\x00" * 23, schema=1) |
| 127 | + |
| 128 | + def test_exact_length_accepted_schema1(self) -> None: |
| 129 | + msg = StmtResponse.decode_body(self._body_schema1(tail_offset=7), schema=1) |
| 130 | + assert msg.db_id == 1 |
| 131 | + assert msg.stmt_id == 42 |
| 132 | + assert msg.num_params == 3 |
| 133 | + assert msg.tail_offset == 7 |
| 134 | + |
| 135 | + def test_trailing_bytes_rejected_schema1(self) -> None: |
| 136 | + body = self._body_schema1() + b"\x02" |
| 137 | + with pytest.raises(DecodeError, match=r"StmtResponse schema=1 body must be exactly 24"): |
| 138 | + StmtResponse.decode_body(body, schema=1) |
0 commit comments