Skip to content

Commit 51f7850

Browse files
Reject trailing bytes in request and response decoders
Upstream C drives every decode through a struct cursor with an explicit cap, so extra bytes past the declared fields are detected as DQLITE_PARSE. Prior cycles applied the strict-length pattern to EmptyResponse / DbResponse / ResultResponse and LeaderRequest; this change extends it to:

- every fixed-size request decoder (Client, Heartbeat, Interrupt, Remove, Transfer, Weight, Finalize, Cluster)
- every variable-size request decoder (Open, Prepare, Exec, Query, ExecSql, QuerySql, Connect, Add, Dump) via an explicit offset-equals-length check after the last field
- AssignRequest, whose ``len(data) >= 16`` threshold silently dropped any trailing bytes past offset 16; now equality-checked against exactly 8 (Promote) or 16 (Assign)
- DescribeRequest's format field: upstream defines only V0=0 and rejects anything else with SQLITE_PROTOCOL; reject client-side at construction and decode
- FilesResponse, whose loop returned without verifying buffer exhaustion on the last file
- RowsResponse zero-column fast path, which read the 8-byte DONE/PART marker and returned without checking buffer exhaustion

Decoder round-trips remain byte-identical; the change only rejects previously-silently-accepted corrupt or extended bodies.
1 parent 216426b commit 51f7850

3 files changed

Lines changed: 294 additions & 13 deletions

File tree

src/dqlitewire/messages/requests.py

Lines changed: 77 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def encode_body(self) -> bytes:
7777

7878
@classmethod
7979
def decode_body(cls, data: bytes, schema: int = 0) -> "ClientRequest":
80+
if len(data) != 8:
81+
raise DecodeError(f"ClientRequest body must be 8 bytes, got {len(data)}")
8082
client_id = decode_uint64(data)
8183
return cls(client_id)
8284

@@ -108,6 +110,8 @@ def encode_body(self) -> bytes:
108110

109111
@classmethod
110112
def decode_body(cls, data: bytes, schema: int = 0) -> "HeartbeatRequest":
113+
if len(data) != 8:
114+
raise DecodeError(f"HeartbeatRequest body must be 8 bytes, got {len(data)}")
111115
timestamp = decode_uint64(data)
112116
return cls(timestamp)
113117

@@ -145,7 +149,10 @@ def decode_body(cls, data: bytes, schema: int = 0) -> "OpenRequest":
145149
name, offset = decode_text(data)
146150
flags = decode_uint64(data[offset:])
147151
offset += 8
148-
vfs, _ = decode_text(data[offset:])
152+
vfs, consumed = decode_text(data[offset:])
153+
offset += consumed
154+
if offset != len(data):
155+
raise DecodeError(f"OpenRequest has {len(data) - offset} trailing bytes")
149156
return cls(name, flags, vfs)
150157

151158

@@ -181,7 +188,10 @@ def encode_body(self) -> bytes:
181188
@classmethod
182189
def decode_body(cls, data: bytes, schema: int = 0) -> "PrepareRequest":
183190
db_id = decode_uint64(data)
184-
sql, _ = decode_text(data[8:])
191+
sql, consumed = decode_text(data[8:])
192+
offset = 8 + consumed
193+
if offset != len(data):
194+
raise DecodeError(f"PrepareRequest has {len(data) - offset} trailing bytes")
185195
return cls(db_id, sql, schema=schema)
186196

187197

@@ -224,7 +234,10 @@ def encode_body(self) -> bytes:
224234
def decode_body(cls, data: bytes, schema: int = 0) -> "ExecRequest":
225235
db_id = decode_uint32(data)
226236
stmt_id = decode_uint32(data[4:])
227-
params, _ = decode_params_tuple(data[8:], schema=schema, buffer_offset=8)
237+
params, consumed = decode_params_tuple(data[8:], schema=schema, buffer_offset=8)
238+
offset = 8 + consumed
239+
if offset != len(data):
240+
raise DecodeError(f"ExecRequest has {len(data) - offset} trailing bytes")
228241
return cls(db_id, stmt_id, params, _decoded_schema=schema)
229242

230243

@@ -262,7 +275,10 @@ def encode_body(self) -> bytes:
262275
def decode_body(cls, data: bytes, schema: int = 0) -> "QueryRequest":
263276
db_id = decode_uint32(data)
264277
stmt_id = decode_uint32(data[4:])
265-
params, _ = decode_params_tuple(data[8:], schema=schema, buffer_offset=8)
278+
params, consumed = decode_params_tuple(data[8:], schema=schema, buffer_offset=8)
279+
offset = 8 + consumed
280+
if offset != len(data):
281+
raise DecodeError(f"QueryRequest has {len(data) - offset} trailing bytes")
266282
return cls(db_id, stmt_id, params, _decoded_schema=schema)
267283

268284

@@ -287,6 +303,8 @@ def encode_body(self) -> bytes:
287303

288304
@classmethod
289305
def decode_body(cls, data: bytes, schema: int = 0) -> "FinalizeRequest":
306+
if len(data) != 8:
307+
raise DecodeError(f"FinalizeRequest body must be 8 bytes, got {len(data)}")
290308
db_id = decode_uint32(data)
291309
stmt_id = decode_uint32(data[4:])
292310
return cls(db_id, stmt_id)
@@ -327,7 +345,10 @@ def decode_body(cls, data: bytes, schema: int = 0) -> "ExecSqlRequest":
327345
db_id = decode_uint64(data)
328346
sql, offset = decode_text(data[8:])
329347
offset += 8
330-
params, _ = decode_params_tuple(data[offset:], schema=schema, buffer_offset=offset)
348+
params, consumed = decode_params_tuple(data[offset:], schema=schema, buffer_offset=offset)
349+
offset += consumed
350+
if offset != len(data):
351+
raise DecodeError(f"ExecSqlRequest has {len(data) - offset} trailing bytes")
331352
return cls(db_id, sql, params, _decoded_schema=schema)
332353

333354

@@ -366,7 +387,10 @@ def decode_body(cls, data: bytes, schema: int = 0) -> "QuerySqlRequest":
366387
db_id = decode_uint64(data)
367388
sql, offset = decode_text(data[8:])
368389
offset += 8
369-
params, _ = decode_params_tuple(data[offset:], schema=schema, buffer_offset=offset)
390+
params, consumed = decode_params_tuple(data[offset:], schema=schema, buffer_offset=offset)
391+
offset += consumed
392+
if offset != len(data):
393+
raise DecodeError(f"QuerySqlRequest has {len(data) - offset} trailing bytes")
370394
return cls(db_id, sql, params, _decoded_schema=schema)
371395

372396

@@ -389,6 +413,8 @@ def encode_body(self) -> bytes:
389413

390414
@classmethod
391415
def decode_body(cls, data: bytes, schema: int = 0) -> "InterruptRequest":
416+
if len(data) != 8:
417+
raise DecodeError(f"InterruptRequest body must be 8 bytes, got {len(data)}")
392418
db_id = decode_uint64(data)
393419
return cls(db_id)
394420

@@ -419,7 +445,10 @@ def encode_body(self) -> bytes:
419445
@classmethod
420446
def decode_body(cls, data: bytes, schema: int = 0) -> "ConnectRequest":
421447
node_id = decode_uint64(data)
422-
address, _ = decode_text(data[8:])
448+
address, consumed = decode_text(data[8:])
449+
offset = 8 + consumed
450+
if offset != len(data):
451+
raise DecodeError(f"ConnectRequest has {len(data) - offset} trailing bytes")
423452
return cls(node_id, address)
424453

425454

@@ -444,7 +473,10 @@ def encode_body(self) -> bytes:
444473
@classmethod
445474
def decode_body(cls, data: bytes, schema: int = 0) -> "AddRequest":
446475
node_id = decode_uint64(data)
447-
address, _ = decode_text(data[8:])
476+
address, consumed = decode_text(data[8:])
477+
offset = 8 + consumed
478+
if offset != len(data):
479+
raise DecodeError(f"AddRequest has {len(data) - offset} trailing bytes")
448480
return cls(node_id, address)
449481

450482

@@ -485,11 +517,19 @@ def encode_body(self) -> bytes:
485517

486518
@classmethod
487519
def decode_body(cls, data: bytes, schema: int = 0) -> "AssignRequest":
488-
node_id = decode_uint64(data)
489-
role: int | None = None
490-
if len(data) >= 16:
520+
# Upstream emits bodies of exactly 8 (PROMOTE) or 16 (ASSIGN)
521+
# bytes. Reject anything else rather than silently dropping
522+
# trailing bytes — parity with the C cursor-cap semantics.
523+
if len(data) == 8:
524+
node_id = decode_uint64(data)
525+
return cls(node_id, None)
526+
if len(data) == 16:
527+
node_id = decode_uint64(data)
491528
role = decode_uint64(data[8:])
492-
return cls(node_id, role)
529+
return cls(node_id, role)
530+
raise DecodeError(
531+
f"AssignRequest body must be 8 (PROMOTE) or 16 (ASSIGN) bytes, got {len(data)}"
532+
)
493533

494534

495535
@dataclass
@@ -511,6 +551,8 @@ def encode_body(self) -> bytes:
511551

512552
@classmethod
513553
def decode_body(cls, data: bytes, schema: int = 0) -> "RemoveRequest":
554+
if len(data) != 8:
555+
raise DecodeError(f"RemoveRequest body must be 8 bytes, got {len(data)}")
514556
node_id = decode_uint64(data)
515557
return cls(node_id)
516558

@@ -531,7 +573,9 @@ def encode_body(self) -> bytes:
531573

532574
@classmethod
533575
def decode_body(cls, data: bytes, schema: int = 0) -> "DumpRequest":
534-
name, _ = decode_text(data)
576+
name, consumed = decode_text(data)
577+
if consumed != len(data):
578+
raise DecodeError(f"DumpRequest has {len(data) - consumed} trailing bytes")
535579
return cls(name)
536580

537581

@@ -566,6 +610,8 @@ def encode_body(self) -> bytes:
566610

567611
@classmethod
568612
def decode_body(cls, data: bytes, schema: int = 0) -> "ClusterRequest":
613+
if len(data) != 8:
614+
raise DecodeError(f"ClusterRequest body must be 8 bytes, got {len(data)}")
569615
format_val = decode_uint64(data)
570616
if format_val == 0:
571617
raise DecodeError(
@@ -595,6 +641,8 @@ def encode_body(self) -> bytes:
595641

596642
@classmethod
597643
def decode_body(cls, data: bytes, schema: int = 0) -> "TransferRequest":
644+
if len(data) != 8:
645+
raise DecodeError(f"TransferRequest body must be 8 bytes, got {len(data)}")
598646
target_node_id = decode_uint64(data)
599647
return cls(target_node_id)
600648

@@ -604,6 +652,11 @@ class DescribeRequest(Message):
604652
"""Request database schema description.
605653
606654
Body: uint64 format
655+
656+
Upstream defines only ``DQLITE_REQUEST_DESCRIBE_FORMAT_V0 = 0``
657+
(``gateway.c`` rejects anything else with ``SQLITE_PROTOCOL``).
658+
Reject unknown formats client-side so callers get a local
659+
``ValueError`` instead of a confusing server failure.
607660
"""
608661

609662
MSG_TYPE: ClassVar[int] = RequestType.DESCRIBE
@@ -612,13 +665,22 @@ class DescribeRequest(Message):
612665

613666
def __post_init__(self) -> None:
614667
_check_uint64("format", self.format)
668+
if self.format != 0:
669+
raise ValueError(
670+
f"DescribeRequest format must be 0 (V0); upstream rejects "
671+
f"anything else with SQLITE_PROTOCOL. Got {self.format}."
672+
)
615673

616674
def encode_body(self) -> bytes:
617675
return encode_uint64(self.format)
618676

619677
@classmethod
620678
def decode_body(cls, data: bytes, schema: int = 0) -> "DescribeRequest":
679+
if len(data) != 8:
680+
raise DecodeError(f"DescribeRequest body must be 8 bytes, got {len(data)}")
621681
format_val = decode_uint64(data)
682+
if format_val != 0:
683+
raise DecodeError(f"DescribeRequest format must be 0 (V0); got {format_val}")
622684
return cls(format_val)
623685

624686

@@ -641,5 +703,7 @@ def encode_body(self) -> bytes:
641703

642704
@classmethod
643705
def decode_body(cls, data: bytes, schema: int = 0) -> "WeightRequest":
706+
if len(data) != 8:
707+
raise DecodeError(f"WeightRequest body must be 8 bytes, got {len(data)}")
644708
weight = decode_uint64(data)
645709
return cls(weight)

src/dqlitewire/messages/responses.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,15 @@ def decode_body(
495495
raise DecodeError(
496496
f"Expected DONE or PART marker for zero-column result, got 0x{marker.hex()}"
497497
)
498+
# The zero-column fast path is Python-specific (upstream C never
499+
# emits zero-column result sets); enforce buffer exhaustion to
500+
# match the strict-decode pattern used by every sibling decoder.
501+
end = offset + WORD_SIZE
502+
if end != len(view):
503+
raise DecodeError(
504+
f"RowsResponse zero-column body has {len(view) - end} "
505+
"trailing bytes after DONE/PART marker"
506+
)
498507
return cls(
499508
column_names=[],
500509
column_types=[],
@@ -648,6 +657,14 @@ def decode_body(cls, data: bytes, schema: int = 0) -> "FilesResponse":
648657
# No padding after content — matches Go's byte-by-byte read.
649658
offset += size
650659
files[name] = content
660+
# Upstream client enforces `cursor.cap == fs[i].size` at each
661+
# iteration; on the last file that amounts to "body must be
662+
# exhausted." Mirror the strictness so corrupt / malicious
663+
# trailing bytes cannot vanish silently.
664+
if offset != len(view):
665+
raise DecodeError(
666+
f"FilesResponse has {len(view) - offset} trailing bytes after last file"
667+
)
651668
return cls(files)
652669

653670

0 commit comments

Comments
 (0)