Skip to content

Commit e13c53a

Browse files
Preserve decoded schema on Exec/Query request round-trip
The upstream C client sends schema=1 for ExecRequest / QueryRequest / ExecSqlRequest / QuerySqlRequest regardless of param count, but the Python encoder's count-based heuristic downgraded the schema byte to 0 on re-encode whenever params numbered <= 255. A proxy or mock server that decoded then re-encoded a real client's bytes therefore emitted a different wire shape than it received. Carry the decoded schema on a private dataclass field and honour it on encode so the round-trip stays byte-identical; default construction still uses the count heuristic for callers that do not hold a hint. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ff137e2 commit e13c53a

2 files changed

Lines changed: 102 additions & 4 deletions

File tree

src/dqlitewire/messages/requests.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,20 @@ class ExecRequest(Message):
198198
db_id: int
199199
stmt_id: int
200200
params: Sequence[Any] = field(default_factory=list)
201+
# Preserves the header schema byte seen on decode so a decode →
202+
# re-encode round-trip emits byte-identical output even when the
203+
# upstream C client used schema=1 with ≤255 params (which the count
204+
# heuristic alone would otherwise downgrade to schema=0). Excluded
205+
# from repr/compare so it stays an internal round-trip hint.
206+
_decoded_schema: int | None = field(default=None, repr=False, compare=False)
201207

202208
def __post_init__(self) -> None:
203209
_check_uint32("db_id", self.db_id)
204210
_check_uint32("stmt_id", self.stmt_id)
205211

206212
def _get_schema(self) -> int:
213+
if self._decoded_schema is not None:
214+
return self._decoded_schema
207215
return 1 if len(self.params) > 255 else 0
208216

209217
def encode_body(self) -> bytes:
@@ -217,7 +225,7 @@ def decode_body(cls, data: bytes, schema: int = 0) -> "ExecRequest":
217225
db_id = decode_uint32(data)
218226
stmt_id = decode_uint32(data[4:])
219227
params, _ = decode_params_tuple(data[8:], schema=schema, buffer_offset=8)
220-
return cls(db_id, stmt_id, params)
228+
return cls(db_id, stmt_id, params, _decoded_schema=schema)
221229

222230

223231
@dataclass
@@ -233,12 +241,15 @@ class QueryRequest(Message):
233241
db_id: int
234242
stmt_id: int
235243
params: Sequence[Any] = field(default_factory=list)
244+
_decoded_schema: int | None = field(default=None, repr=False, compare=False)
236245

237246
def __post_init__(self) -> None:
238247
_check_uint32("db_id", self.db_id)
239248
_check_uint32("stmt_id", self.stmt_id)
240249

241250
def _get_schema(self) -> int:
251+
if self._decoded_schema is not None:
252+
return self._decoded_schema
242253
return 1 if len(self.params) > 255 else 0
243254

244255
def encode_body(self) -> bytes:
@@ -252,7 +263,7 @@ def decode_body(cls, data: bytes, schema: int = 0) -> "QueryRequest":
252263
db_id = decode_uint32(data)
253264
stmt_id = decode_uint32(data[4:])
254265
params, _ = decode_params_tuple(data[8:], schema=schema, buffer_offset=8)
255-
return cls(db_id, stmt_id, params)
266+
return cls(db_id, stmt_id, params, _decoded_schema=schema)
256267

257268

258269
@dataclass
@@ -294,11 +305,14 @@ class ExecSqlRequest(Message):
294305
db_id: int
295306
sql: str
296307
params: Sequence[Any] = field(default_factory=list)
308+
_decoded_schema: int | None = field(default=None, repr=False, compare=False)
297309

298310
def __post_init__(self) -> None:
299311
_check_uint64("db_id", self.db_id)
300312

301313
def _get_schema(self) -> int:
314+
if self._decoded_schema is not None:
315+
return self._decoded_schema
302316
return 1 if len(self.params) > 255 else 0
303317

304318
def encode_body(self) -> bytes:
@@ -314,7 +328,7 @@ def decode_body(cls, data: bytes, schema: int = 0) -> "ExecSqlRequest":
314328
sql, offset = decode_text(data[8:])
315329
offset += 8
316330
params, _ = decode_params_tuple(data[offset:], schema=schema, buffer_offset=offset)
317-
return cls(db_id, sql, params)
331+
return cls(db_id, sql, params, _decoded_schema=schema)
318332

319333

320334
@dataclass
@@ -330,11 +344,14 @@ class QuerySqlRequest(Message):
330344
db_id: int
331345
sql: str
332346
params: Sequence[Any] = field(default_factory=list)
347+
_decoded_schema: int | None = field(default=None, repr=False, compare=False)
333348

334349
def __post_init__(self) -> None:
335350
_check_uint64("db_id", self.db_id)
336351

337352
def _get_schema(self) -> int:
353+
if self._decoded_schema is not None:
354+
return self._decoded_schema
338355
return 1 if len(self.params) > 255 else 0
339356

340357
def encode_body(self) -> bytes:
@@ -350,7 +367,7 @@ def decode_body(cls, data: bytes, schema: int = 0) -> "QuerySqlRequest":
350367
sql, offset = decode_text(data[8:])
351368
offset += 8
352369
params, _ = decode_params_tuple(data[offset:], schema=schema, buffer_offset=offset)
353-
return cls(db_id, sql, params)
370+
return cls(db_id, sql, params, _decoded_schema=schema)
354371

355372

356373
@dataclass

tests/test_messages_requests.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,84 @@ def test_bool_rejected_for_uint64_client_id(self) -> None:
625625

626626
with pytest.raises(TypeError, match="client_id must be int"):
627627
ClientRequest(client_id=True)
628+
629+
630+
class TestParamsBodySchemaRoundtrip:
631+
"""Upstream C clients emit schema=1 for Exec/Query*Request
632+
unconditionally, regardless of param count. Decoding a small-param
633+
schema=1 body and re-encoding must be byte-identical, so a proxy
634+
or mock server that round-trips the bytes faithfully does not
635+
downgrade the schema bit seen over the wire.
636+
"""
637+
638+
@pytest.mark.parametrize(
639+
"cls_name",
640+
["ExecRequest", "QueryRequest"],
641+
)
642+
def test_prepared_schema_1_small_params_roundtrip(self, cls_name: str) -> None:
643+
"""ExecRequest / QueryRequest: 4-byte db_id + 4-byte stmt_id +
644+
PARAMS32 body with 3 params. Construct with schema=1 explicitly,
645+
then verify decode → re-encode == original bytes.
646+
"""
647+
from dqlitewire.codec import encode_message
648+
from dqlitewire.messages import ExecRequest, QueryRequest
649+
650+
classes = {"ExecRequest": ExecRequest, "QueryRequest": QueryRequest}
651+
cls = classes[cls_name]
652+
653+
from dqlitewire.codec import decode_message
654+
655+
# Manually force schema=1 by setting _decoded_schema on the
656+
# source message.
657+
original = cls(db_id=1, stmt_id=2, params=[42, "hello", None], _decoded_schema=1)
658+
original_bytes = encode_message(original)
659+
660+
# Header schema byte is the 6th byte (after 4-byte size_words
661+
# + 1-byte msg_type). Confirm the wire reflects schema=1.
662+
assert original_bytes[5] == 1, "expected schema=1 in the header"
663+
664+
decoded = decode_message(original_bytes, is_request=True)
665+
assert isinstance(decoded, cls)
666+
assert list(decoded.params) == list(original.params)
667+
668+
re_encoded = encode_message(decoded)
669+
assert re_encoded == original_bytes
670+
671+
@pytest.mark.parametrize(
672+
"cls_name",
673+
["ExecSqlRequest", "QuerySqlRequest"],
674+
)
675+
def test_sql_schema_1_small_params_roundtrip(self, cls_name: str) -> None:
676+
from dqlitewire.codec import decode_message, encode_message
677+
from dqlitewire.messages import ExecSqlRequest, QuerySqlRequest
678+
679+
classes = {"ExecSqlRequest": ExecSqlRequest, "QuerySqlRequest": QuerySqlRequest}
680+
cls = classes[cls_name]
681+
682+
original = cls(db_id=1, sql="SELECT 1", params=[1, 2, 3], _decoded_schema=1)
683+
original_bytes = encode_message(original)
684+
assert original_bytes[5] == 1
685+
686+
decoded = decode_message(original_bytes, is_request=True)
687+
assert isinstance(decoded, cls)
688+
assert list(decoded.params) == list(original.params)
689+
assert decoded.sql == "SELECT 1"
690+
691+
re_encoded = encode_message(decoded)
692+
assert re_encoded == original_bytes
693+
694+
def test_default_construction_uses_heuristic(self) -> None:
695+
"""When a caller constructs a fresh ExecRequest without a
696+
decoded schema hint, the count heuristic still applies: ≤255
697+
params → schema=0.
698+
"""
699+
from dqlitewire.messages import ExecRequest
700+
701+
msg = ExecRequest(db_id=1, stmt_id=2, params=[1, 2, 3])
702+
assert msg._get_schema() == 0
703+
704+
def test_large_params_force_schema_1_without_hint(self) -> None:
705+
from dqlitewire.messages import ExecRequest
706+
707+
msg = ExecRequest(db_id=1, stmt_id=2, params=list(range(300)))
708+
assert msg._get_schema() == 1

0 commit comments

Comments
 (0)