Skip to content

Commit 93382b2

Browse files
Narrow encode_value / decode_value types and accept bytes-like in inference
Introduce two PEP 695 type aliases at the top of types.py — WireInput (bool|int|float|str|bytes|bytearray|memoryview|None) for the encoder input set, and WireValue (bool|int|float|str|bytes|None) for the decoder output set — and use them to narrow the signatures of encode_value and decode_value, which were previously typed as Any. The encoder's type inference now maps bytearray and memoryview to BLOB the way the explicit BLOB branch already did (and the way stdlib sqlite3 does), so zero-copy or mutation-built payloads no longer hit EncodeError on pure inference paths. The decoder keeps producing plain bytes, so WireValue is intentionally narrower than WireInput.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e1b4428 commit 93382b2

2 files changed

Lines changed: 103 additions & 7 deletions

File tree

src/dqlitewire/types.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,24 @@
1111
"""
1212

1313
import struct
14-
from typing import Any
1514

1615
from dqlitewire.constants import WORD_SIZE, ValueType
1716
from dqlitewire.exceptions import DecodeError, EncodeError
1817

18+
# Exact set of Python types ``encode_value`` accepts. Callers see a
19+
# type-checker error if they pass something else, instead of a runtime
20+
# EncodeError at the first wire round-trip. ``bytes``-like siblings
21+
# (``bytearray``, ``memoryview``) are accepted because the BLOB encoder
22+
# normalises them through ``bytes(value)``; the inference path maps
23+
# them to ValueType.BLOB for the same reason stdlib ``sqlite3`` does.
24+
type WireInput = bool | int | float | str | bytes | bytearray | memoryview | None
25+
26+
# Exact set of Python types ``decode_value`` may return (first element
27+
# of the ``(value, consumed)`` tuple). Narrower than ``Any`` — wire
28+
# values are always one of these primitives, and the driver layer
29+
# widens to ``Any`` only at the PEP 249 row-tuple boundary.
30+
type WireValue = bool | int | float | str | bytes | None
31+
1932
# Per-BLOB byte cap. The overall frame-size cap in ``buffer.py`` (64 MiB)
2033
# already bounds any single message, but a hostile or buggy peer can
2134
# otherwise pack a single BLOB field that consumes the whole frame. The
@@ -258,7 +271,7 @@ def decode_blob(data: bytes | memoryview) -> tuple[bytes, int]:
258271
return bytes(data[8 : 8 + length]), total_size
259272

260273

261-
def encode_value(value: Any, value_type: ValueType | None = None) -> tuple[bytes, ValueType]:
274+
def encode_value(value: WireInput, value_type: ValueType | None = None) -> tuple[bytes, ValueType]:
262275
"""Encode a Python value to wire format.
263276
264277
If value_type is not provided, it's inferred from the Python type.
@@ -281,14 +294,19 @@ def encode_value(value: Any, value_type: ValueType | None = None) -> tuple[bytes
281294
value_type = ValueType.FLOAT
282295
elif isinstance(value, str):
283296
value_type = ValueType.TEXT
284-
elif isinstance(value, bytes):
297+
elif isinstance(value, (bytes, bytearray, memoryview)):
298+
# Parity with the explicit BLOB branch and with stdlib
299+
# ``sqlite3``: all three bytes-like types infer to BLOB.
300+
# Callers building payloads via mutation (bytearray) or
301+
# zero-copy slicing (memoryview) no longer need to wrap
302+
# values in ``bytes(...)`` before passing them here.
285303
value_type = ValueType.BLOB
286304
else:
287305
raise EncodeError(
288306
f"Cannot infer wire type for value of type {type(value).__name__!r}. "
289-
f"The wire codec only accepts bool, int, float, str, bytes, or None. "
290-
f"Callers passing datetime/date/etc. must convert to str (for ISO8601) "
291-
f"or int (for UNIXTIME) at the driver layer."
307+
f"The wire codec only accepts bool, int, float, str, bytes-like, "
308+
f"or None. Callers passing datetime/date/etc. must convert to str "
309+
f"(for ISO8601) or int (for UNIXTIME) at the driver layer."
292310
)
293311

294312
if value_type == ValueType.BOOLEAN:
@@ -351,7 +369,7 @@ def encode_value(value: Any, value_type: ValueType | None = None) -> tuple[bytes
351369
raise EncodeError(f"Unknown value type: {value_type}")
352370

353371

354-
def decode_value(data: bytes | memoryview, value_type: ValueType) -> tuple[Any, int]:
372+
def decode_value(data: bytes | memoryview, value_type: ValueType) -> tuple[WireValue, int]:
355373
"""Decode a value from wire format.
356374
357375
Returns (value, bytes_consumed).

tests/test_types.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,21 @@ def test_encode_value_blob_accepts_memoryview(self) -> None:
701701
decoded, _ = decode_value(encoded, ValueType.BLOB)
702702
assert decoded == b"\x03\x04"
703703

704+
def test_encode_value_infers_blob_from_bytearray(self) -> None:
705+
"""Inference path must accept every bytes-like the explicit BLOB
706+
branch accepts, otherwise zero-copy / mutation-built payloads
707+
hit EncodeError on pure inference paths."""
708+
encoded, vtype = encode_value(bytearray(b"\x01\x02\x03"))
709+
assert vtype == ValueType.BLOB
710+
decoded, _ = decode_value(encoded, ValueType.BLOB)
711+
assert decoded == b"\x01\x02\x03"
712+
713+
def test_encode_value_infers_blob_from_memoryview(self) -> None:
714+
encoded, vtype = encode_value(memoryview(b"\x04\x05\x06\x07"))
715+
assert vtype == ValueType.BLOB
716+
decoded, _ = decode_value(encoded, ValueType.BLOB)
717+
assert decoded == b"\x04\x05\x06\x07"
718+
704719
def test_decode_int64_short_data(self) -> None:
705720
with pytest.raises(DecodeError):
706721
decode_int64(b"\x00" * 7)
@@ -823,3 +838,66 @@ def __index__(self) -> int:
823838

824839
with pytest.raises(EncodeError, match="Cannot infer wire type"):
825840
encode_value(OnlyIndex())
841+
842+
843+
class TestWireTypeAliases:
844+
"""Pin the public WireInput / WireValue type aliases so an
845+
accidental widening to ``Any`` or narrowing that drops a supported
846+
runtime type is caught by the test suite.
847+
"""
848+
849+
def test_wire_input_matches_runtime_accepted_types(self) -> None:
850+
"""WireInput is the documented set of types ``encode_value``
851+
accepts. Drift it and a downstream caller would believe their
852+
bound parameter is acceptable when it isn't (or vice-versa).
853+
"""
854+
import typing
855+
856+
from dqlitewire.types import WireInput
857+
858+
# PEP 695 ``type X = Union`` produces a TypeAliasType; the
859+
# underlying union is reachable via ``__value__``.
860+
expected_members = {
861+
bool,
862+
int,
863+
float,
864+
str,
865+
bytes,
866+
bytearray,
867+
memoryview,
868+
type(None),
869+
}
870+
assert set(typing.get_args(WireInput.__value__)) == expected_members
871+
872+
def test_encode_value_parameter_is_wire_input(self) -> None:
873+
import typing
874+
875+
from dqlitewire.types import WireInput, encode_value
876+
877+
hints = typing.get_type_hints(encode_value)
878+
assert hints["value"] is WireInput
879+
880+
def test_wire_value_matches_runtime_decoded_types(self) -> None:
881+
"""WireValue is the union of possible first-element returns of
882+
``decode_value``. The decoder never produces bytearray/memoryview
883+
(it always returns ``bytes``), so this alias is narrower than
884+
WireInput.
885+
"""
886+
import typing
887+
888+
from dqlitewire.types import WireValue
889+
890+
expected_members = {bool, int, float, str, bytes, type(None)}
891+
assert set(typing.get_args(WireValue.__value__)) == expected_members
892+
893+
def test_decode_value_return_is_wire_value(self) -> None:
894+
import typing
895+
896+
from dqlitewire.types import WireValue, decode_value
897+
898+
hints = typing.get_type_hints(decode_value)
899+
# Return is tuple[WireValue, int] — verify the first arg is
900+
# exactly the alias (identity), not a structural copy.
901+
return_args = typing.get_args(hints["return"])
902+
assert return_args[0] is WireValue
903+
assert return_args[1] is int

0 commit comments

Comments
 (0)