Skip to content

Commit cac8c81

Browse files
chore: batch code cleanup (issues 068-080)
- 068: Use encode_uint64/decode_uint64 for BOOLEAN (matches C uint64_t)
- 069: Fix Header docstring: schema is 0 or 1, not "always 0"
- 070: Remove stale line-number reference in types.py comment
- 071: Change zip(strict=False) to strict=True in encode_row_values
- 072: Make _SUPPORTED_VERSIONS a frozenset
- 074: Extract _check_torn_size helper, deduplicate 3 copies
- 076: Move ~26 lazy imports to top-level in requests.py/responses.py
- 077: Extract _COMPACT_THRESHOLD named constant (was magic 4096)
- 078: Sort __all__ alphabetically in __init__.py
- 080: Add [dependency-groups] dev to pyproject.toml

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1201aef commit cac8c81

9 files changed

Lines changed: 60 additions & 120 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ Issues = "https://github.com/letsdiscodev/python-dqlite-wire/issues"
3131
[project.optional-dependencies]
3232
dev = ["pytest>=8.0", "pytest-cov>=4.0", "mypy>=1.0", "ruff>=0.4"]
3333

34+
[dependency-groups]
35+
dev = ["pytest>=8.0", "pytest-cov>=4.0", "mypy>=1.0", "ruff>=0.4"]
36+
3437
[tool.hatch.build.targets.wheel]
3538
packages = ["src/dqlitewire"]
3639

src/dqlitewire/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,22 +56,22 @@
5656
from dqlitewire.exceptions import DecodeError, EncodeError, ProtocolError
5757

5858
__all__ = [
59+
"DecodeError",
60+
"EncodeError",
5961
"MessageDecoder",
6062
"MessageEncoder",
6163
"PROTOCOL_VERSION",
6264
"PROTOCOL_VERSION_LEGACY",
6365
"ProtocolError",
64-
"DecodeError",
65-
"EncodeError",
66-
"ReadBuffer",
6766
"ROW_DONE_BYTE",
6867
"ROW_DONE_MARKER",
6968
"ROW_PART_BYTE",
7069
"ROW_PART_MARKER",
71-
"WriteBuffer",
70+
"ReadBuffer",
7271
"RequestType",
7372
"ResponseType",
7473
"ValueType",
74+
"WriteBuffer",
7575
]
7676

7777
__version__ = "0.1.0"

src/dqlitewire/buffer.py

Lines changed: 22 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from dqlitewire.constants import HEADER_SIZE, WORD_SIZE
44
from dqlitewire.exceptions import DecodeError, ProtocolError
55

6+
_COMPACT_THRESHOLD = 4096
7+
68

79
class WriteBuffer:
810
"""Buffer for building wire protocol messages.
@@ -122,6 +124,22 @@ def _check_poisoned(self) -> None:
122124
"buffer is poisoned; call reset() and reconnect"
123125
) from self._poisoned
124126

127+
def _check_torn_size(self, size_words: int) -> None:
128+
"""Poison and raise if size_words is structurally impossible.
129+
130+
The wire size field is 4 bytes (uint32 LE), so any value
131+
> 0xFFFFFFFF indicates a torn read from a concurrent realloc
132+
on a free-threaded build. Distinguish from legitimate oversized
133+
messages so the non-poisoning DecodeError contract is preserved.
134+
"""
135+
if size_words > 0xFFFFFFFF:
136+
err = DecodeError(
137+
f"torn header read: size_words={size_words:#x} (>32 bits, "
138+
"indicates concurrent misuse on a free-threaded build)"
139+
)
140+
self.poison(err)
141+
raise err
142+
125143
def feed(self, data: bytes) -> None:
126144
"""Add received data to the buffer.
127145
@@ -239,23 +257,7 @@ def peek_header(self) -> tuple[int, int, int] | None:
239257
return None
240258

241259
size_words = int.from_bytes(self._data[self._pos : self._pos + 4], "little")
242-
# Sanity check for torn reads (issue 051). The wire size
243-
# field is exactly 4 bytes (uint32 little-endian), so any
244-
# value > 0xFFFFFFFF cannot come from a well-formed header
245-
# — it can only come from a ``bytearray`` slice that
246-
# observed torn ``ob_size``/``ob_start`` during a concurrent
247-
# realloc on a free-threaded build, returning more than 4
248-
# bytes. Distinguish this from legitimate oversized messages
249-
# so the non-poisoning ``DecodeError`` recovery contract
250-
# still applies to real wire-oversized messages while torn
251-
# reads poison the buffer.
252-
if size_words > 0xFFFFFFFF:
253-
err = DecodeError(
254-
f"torn header read: size_words={size_words:#x} (>32 bits, "
255-
"indicates concurrent misuse on a free-threaded build)"
256-
)
257-
self.poison(err)
258-
raise err
260+
self._check_torn_size(size_words)
259261
total_size = HEADER_SIZE + (size_words * WORD_SIZE)
260262
if total_size > self._max_message_size:
261263
# Format size in hex: under concurrent misuse (see issue 033)
@@ -285,23 +287,7 @@ def read_message(self) -> bytes | None:
285287
return None
286288

287289
size_words = int.from_bytes(self._data[self._pos : self._pos + 4], "little")
288-
# Sanity check for torn reads (issue 051). The wire size
289-
# field is exactly 4 bytes (uint32 little-endian), so any
290-
# value > 0xFFFFFFFF cannot come from a well-formed header
291-
# — it can only come from a ``bytearray`` slice that
292-
# observed torn ``ob_size``/``ob_start`` during a concurrent
293-
# realloc on a free-threaded build, returning more than 4
294-
# bytes. Distinguish this from legitimate oversized messages
295-
# so the non-poisoning ``DecodeError`` recovery contract
296-
# still applies to real wire-oversized messages while torn
297-
# reads poison the buffer.
298-
if size_words > 0xFFFFFFFF:
299-
err = DecodeError(
300-
f"torn header read: size_words={size_words:#x} (>32 bits, "
301-
"indicates concurrent misuse on a free-threaded build)"
302-
)
303-
self.poison(err)
304-
raise err
290+
self._check_torn_size(size_words)
305291
total_size = HEADER_SIZE + (size_words * WORD_SIZE)
306292

307293
if total_size > self._max_message_size:
@@ -351,23 +337,7 @@ def skip_message(self) -> bool:
351337
return False
352338

353339
size_words = int.from_bytes(self._data[self._pos : self._pos + 4], "little")
354-
# Sanity check for torn reads (issue 051). The wire size
355-
# field is exactly 4 bytes (uint32 little-endian), so any
356-
# value > 0xFFFFFFFF cannot come from a well-formed header
357-
# — it can only come from a ``bytearray`` slice that
358-
# observed torn ``ob_size``/``ob_start`` during a concurrent
359-
# realloc on a free-threaded build, returning more than 4
360-
# bytes. Distinguish this from legitimate oversized messages
361-
# so the non-poisoning ``DecodeError`` recovery contract
362-
# still applies to real wire-oversized messages while torn
363-
# reads poison the buffer.
364-
if size_words > 0xFFFFFFFF:
365-
err = DecodeError(
366-
f"torn header read: size_words={size_words:#x} (>32 bits, "
367-
"indicates concurrent misuse on a free-threaded build)"
368-
)
369-
self.poison(err)
370-
raise err
340+
self._check_torn_size(size_words)
371341
total_size = HEADER_SIZE + (size_words * WORD_SIZE)
372342

373343
# Normal-sized message: only skip when complete to avoid
@@ -459,7 +429,7 @@ def _maybe_compact(self) -> None:
459429
caller fails fast with ``ProtocolError`` instead of reading
460430
from an inconsistent offset.
461431
"""
462-
if self._pos <= 4096:
432+
if self._pos <= _COMPACT_THRESHOLD:
463433
return
464434
try:
465435
new_data = self._data[self._pos :]

src/dqlitewire/codec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
}
9797

9898

99-
_SUPPORTED_VERSIONS = {PROTOCOL_VERSION, PROTOCOL_VERSION_LEGACY}
99+
_SUPPORTED_VERSIONS = frozenset({PROTOCOL_VERSION, PROTOCOL_VERSION_LEGACY})
100100

101101

102102
class MessageEncoder:

src/dqlitewire/messages/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Header:
1616
Format (8 bytes):
1717
- size: uint32 - Size of message body in words (8-byte units)
1818
- type: uint8 - Message type code
19-
- schema: uint8 - Schema version (currently always 0)
19+
- schema: uint8 - Schema version (0 or 1; V1 extends param tuples and StmtResponse)
2020
- reserved: uint16 - Reserved (always 0)
2121
"""
2222

src/dqlitewire/messages/requests.py

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,15 @@
66

77
from dqlitewire.constants import RequestType
88
from dqlitewire.messages.base import Message
9-
from dqlitewire.tuples import encode_params_tuple
10-
from dqlitewire.types import encode_text, encode_uint32, encode_uint64
9+
from dqlitewire.tuples import decode_params_tuple, encode_params_tuple
10+
from dqlitewire.types import (
11+
decode_text,
12+
decode_uint32,
13+
decode_uint64,
14+
encode_text,
15+
encode_uint32,
16+
encode_uint64,
17+
)
1118

1219

1320
def _check_uint32(name: str, value: int) -> None:
@@ -60,8 +67,6 @@ def encode_body(self) -> bytes:
6067

6168
@classmethod
6269
def decode_body(cls, data: bytes, schema: int = 0) -> "ClientRequest":
63-
from dqlitewire.types import decode_uint64
64-
6570
client_id = decode_uint64(data)
6671
return cls(client_id)
6772

@@ -85,8 +90,6 @@ def encode_body(self) -> bytes:
8590

8691
@classmethod
8792
def decode_body(cls, data: bytes, schema: int = 0) -> "HeartbeatRequest":
88-
from dqlitewire.types import decode_uint64
89-
9093
timestamp = decode_uint64(data)
9194
return cls(timestamp)
9295

@@ -115,8 +118,6 @@ def encode_body(self) -> bytes:
115118

116119
@classmethod
117120
def decode_body(cls, data: bytes, schema: int = 0) -> "OpenRequest":
118-
from dqlitewire.types import decode_text, decode_uint64
119-
120121
name, offset = decode_text(data)
121122
flags = decode_uint64(data[offset:])
122123
offset += 8
@@ -155,8 +156,6 @@ def encode_body(self) -> bytes:
155156

156157
@classmethod
157158
def decode_body(cls, data: bytes, schema: int = 0) -> "PrepareRequest":
158-
from dqlitewire.types import decode_text, decode_uint64
159-
160159
db_id = decode_uint64(data)
161160
sql, _ = decode_text(data[8:])
162161
return cls(db_id, sql, schema=schema)
@@ -191,9 +190,6 @@ def encode_body(self) -> bytes:
191190

192191
@classmethod
193192
def decode_body(cls, data: bytes, schema: int = 0) -> "ExecRequest":
194-
from dqlitewire.tuples import decode_params_tuple
195-
from dqlitewire.types import decode_uint32
196-
197193
db_id = decode_uint32(data)
198194
stmt_id = decode_uint32(data[4:])
199195
params, _ = decode_params_tuple(data[8:], schema=schema, buffer_offset=8)
@@ -229,9 +225,6 @@ def encode_body(self) -> bytes:
229225

230226
@classmethod
231227
def decode_body(cls, data: bytes, schema: int = 0) -> "QueryRequest":
232-
from dqlitewire.tuples import decode_params_tuple
233-
from dqlitewire.types import decode_uint32
234-
235228
db_id = decode_uint32(data)
236229
stmt_id = decode_uint32(data[4:])
237230
params, _ = decode_params_tuple(data[8:], schema=schema, buffer_offset=8)
@@ -259,8 +252,6 @@ def encode_body(self) -> bytes:
259252

260253
@classmethod
261254
def decode_body(cls, data: bytes, schema: int = 0) -> "FinalizeRequest":
262-
from dqlitewire.types import decode_uint32
263-
264255
db_id = decode_uint32(data)
265256
stmt_id = decode_uint32(data[4:])
266257
return cls(db_id, stmt_id)
@@ -295,9 +286,6 @@ def encode_body(self) -> bytes:
295286

296287
@classmethod
297288
def decode_body(cls, data: bytes, schema: int = 0) -> "ExecSqlRequest":
298-
from dqlitewire.tuples import decode_params_tuple
299-
from dqlitewire.types import decode_text, decode_uint64
300-
301289
db_id = decode_uint64(data)
302290
sql, offset = decode_text(data[8:])
303291
offset += 8
@@ -334,9 +322,6 @@ def encode_body(self) -> bytes:
334322

335323
@classmethod
336324
def decode_body(cls, data: bytes, schema: int = 0) -> "QuerySqlRequest":
337-
from dqlitewire.tuples import decode_params_tuple
338-
from dqlitewire.types import decode_text, decode_uint64
339-
340325
db_id = decode_uint64(data)
341326
sql, offset = decode_text(data[8:])
342327
offset += 8
@@ -363,8 +348,6 @@ def encode_body(self) -> bytes:
363348

364349
@classmethod
365350
def decode_body(cls, data: bytes, schema: int = 0) -> "InterruptRequest":
366-
from dqlitewire.types import decode_uint64
367-
368351
db_id = decode_uint64(data)
369352
return cls(db_id)
370353

@@ -394,8 +377,6 @@ def encode_body(self) -> bytes:
394377

395378
@classmethod
396379
def decode_body(cls, data: bytes, schema: int = 0) -> "ConnectRequest":
397-
from dqlitewire.types import decode_text, decode_uint64
398-
399380
node_id = decode_uint64(data)
400381
address, _ = decode_text(data[8:])
401382
return cls(node_id, address)
@@ -421,8 +402,6 @@ def encode_body(self) -> bytes:
421402

422403
@classmethod
423404
def decode_body(cls, data: bytes, schema: int = 0) -> "AddRequest":
424-
from dqlitewire.types import decode_text, decode_uint64
425-
426405
node_id = decode_uint64(data)
427406
address, _ = decode_text(data[8:])
428407
return cls(node_id, address)
@@ -465,8 +444,6 @@ def encode_body(self) -> bytes:
465444

466445
@classmethod
467446
def decode_body(cls, data: bytes, schema: int = 0) -> "AssignRequest":
468-
from dqlitewire.types import decode_uint64
469-
470447
node_id = decode_uint64(data)
471448
role: int | None = None
472449
if len(data) >= 16:
@@ -493,8 +470,6 @@ def encode_body(self) -> bytes:
493470

494471
@classmethod
495472
def decode_body(cls, data: bytes, schema: int = 0) -> "RemoveRequest":
496-
from dqlitewire.types import decode_uint64
497-
498473
node_id = decode_uint64(data)
499474
return cls(node_id)
500475

@@ -515,8 +490,6 @@ def encode_body(self) -> bytes:
515490

516491
@classmethod
517492
def decode_body(cls, data: bytes, schema: int = 0) -> "DumpRequest":
518-
from dqlitewire.types import decode_text
519-
520493
name, _ = decode_text(data)
521494
return cls(name)
522495

@@ -540,8 +513,6 @@ def encode_body(self) -> bytes:
540513

541514
@classmethod
542515
def decode_body(cls, data: bytes, schema: int = 0) -> "ClusterRequest":
543-
from dqlitewire.types import decode_uint64
544-
545516
format_val = decode_uint64(data)
546517
return cls(format_val)
547518

@@ -565,8 +536,6 @@ def encode_body(self) -> bytes:
565536

566537
@classmethod
567538
def decode_body(cls, data: bytes, schema: int = 0) -> "TransferRequest":
568-
from dqlitewire.types import decode_uint64
569-
570539
target_node_id = decode_uint64(data)
571540
return cls(target_node_id)
572541

@@ -590,8 +559,6 @@ def encode_body(self) -> bytes:
590559

591560
@classmethod
592561
def decode_body(cls, data: bytes, schema: int = 0) -> "DescribeRequest":
593-
from dqlitewire.types import decode_uint64
594-
595562
format_val = decode_uint64(data)
596563
return cls(format_val)
597564

@@ -615,7 +582,5 @@ def encode_body(self) -> bytes:
615582

616583
@classmethod
617584
def decode_body(cls, data: bytes, schema: int = 0) -> "WeightRequest":
618-
from dqlitewire.types import decode_uint64
619-
620585
weight = decode_uint64(data)
621586
return cls(weight)

0 commit comments

Comments
 (0)