Skip to content

Commit 4204053

Browse files
fix: validate column count consistency in RowsResponse.encode_body
RowsResponse.encode_body() now validates that column_types, row_types, and row values all have lengths matching column_names before encoding. Previously, mismatched lengths produced malformed wire data, surfacing either as confusing errors deep in tuples.py or as a silent misparse on decode. Closes #096 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ee96b81 commit 4204053

2 files changed

Lines changed: 61 additions & 2 deletions

File tree

src/dqlitewire/messages/responses.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ResponseType,
1313
ValueType,
1414
)
15-
from dqlitewire.exceptions import DecodeError
15+
from dqlitewire.exceptions import DecodeError, EncodeError
1616
from dqlitewire.messages.base import Message
1717
from dqlitewire.tuples import (
1818
RowMarker,
@@ -264,7 +264,21 @@ def _get_row_types(self, row_idx: int, row: list[Any]) -> list[ValueType]:
264264
return [encode_value(v)[1] for v in row]
265265

266266
def encode_body(self) -> bytes:
267-
result = encode_uint64(len(self.column_names))
267+
col_count = len(self.column_names)
268+
if self.column_types and len(self.column_types) != col_count:
269+
raise EncodeError(
270+
f"column_types length ({len(self.column_types)}) != "
271+
f"column_names length ({col_count})"
272+
)
273+
for i, row in enumerate(self.rows):
274+
if len(row) != col_count:
275+
raise EncodeError(f"Row {i} has {len(row)} values, expected {col_count}")
276+
if self.row_types and i < len(self.row_types) and len(self.row_types[i]) != col_count:
277+
raise EncodeError(
278+
f"row_types[{i}] has {len(self.row_types[i])} types, expected {col_count}"
279+
)
280+
281+
result = encode_uint64(col_count)
268282

269283
# Column names
270284
for name in self.column_names:

tests/test_messages_responses.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Tests for response message encoding/decoding."""
22

3+
import pytest
4+
35
from dqlitewire.constants import HEADER_SIZE, ResponseType, ValueType
6+
from dqlitewire.exceptions import EncodeError
47
from dqlitewire.messages.base import Header
58
from dqlitewire.messages.responses import (
69
DbResponse,
@@ -576,6 +579,48 @@ def build_body(n_rows: int) -> bytes:
576579
assert len(decoded.rows) == 2
577580

578581

582+
class TestRowsResponseEncodeValidation:
583+
"""Encode-time validation of column count consistency."""
584+
585+
def test_mismatched_column_types_length(self) -> None:
586+
resp = RowsResponse(
587+
column_names=["a", "b", "c"],
588+
column_types=[ValueType.INTEGER],
589+
rows=[],
590+
)
591+
with pytest.raises(EncodeError, match="column_types length"):
592+
resp.encode_body()
593+
594+
def test_mismatched_row_values_length(self) -> None:
595+
resp = RowsResponse(
596+
column_names=["a", "b"],
597+
column_types=[ValueType.INTEGER, ValueType.TEXT],
598+
row_types=[[ValueType.INTEGER, ValueType.TEXT]],
599+
rows=[[1]], # only 1 value, expected 2
600+
)
601+
with pytest.raises(EncodeError, match="Row 0 has 1 values, expected 2"):
602+
resp.encode_body()
603+
604+
def test_mismatched_row_types_length(self) -> None:
605+
resp = RowsResponse(
606+
column_names=["a", "b"],
607+
row_types=[[ValueType.INTEGER]], # only 1 type, expected 2
608+
rows=[[1, 2]],
609+
)
610+
with pytest.raises(EncodeError, match="row_types\\[0\\] has 1 types, expected 2"):
611+
resp.encode_body()
612+
613+
def test_empty_column_types_with_rows_ok(self) -> None:
614+
"""Empty column_types is valid — types are inferred from values."""
615+
resp = RowsResponse(
616+
column_names=["a"],
617+
rows=[[42]],
618+
)
619+
encoded = resp.encode_body()
620+
decoded = RowsResponse.decode_body(encoded)
621+
assert decoded.rows == [[42]]
622+
623+
579624
class TestRowsResponseValueTypes:
580625
"""Full RowsResponse round-trips with BOOLEAN, UNIXTIME, ISO8601, and BLOB."""
581626

0 commit comments

Comments
 (0)