Skip to content

Commit dbf1ea5

Browse files
perf: wrap body in memoryview to eliminate quadratic decode
Every body decoder in responses.py (RowsResponse, FilesResponse, ServersResponse) and tuples.py walked the body with an integer offset but passed `data[offset:]` into helpers. Each slice of a `bytes` object copies the remaining tail, producing O(N²) cumulative memcpy on messages with many small rows. Measured quadratic scaling: 50k rows took 2s, 100k took 11s, 200k took 48s. Wrap the body in a `memoryview` at the top of each `decode_body` so per-iteration slices are O(1) views. Widen the primitive decoder signatures (decode_uint32/uint64/int64/double/text/blob, decode_row_header, decode_row_values, decode_value) to accept `bytes | memoryview`. Use `struct.unpack` (already memoryview-compatible) and a chunked accumulating scan for `decode_text`'s NUL search — memoryview has no `.index(bytes)` method, so the scan reads 4 KiB chunks at a time via `bytes(data[a:b]).find(b"\x00")` and accumulates into a list. The chunked accumulation supports arbitrarily long TEXT values (SQLite TEXT columns routinely exceed 4 KiB for JSON / prose) while keeping per-chunk copy cost bounded; because the terminator is a single byte, it can never straddle a chunk boundary. Post-fix: 10x input takes ~10x time, not ~100x. Regression tests check both the linear scaling ratio and long-text roundtrip for memoryview inputs. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9db8390 commit dbf1ea5

4 files changed

Lines changed: 252 additions & 46 deletions

File tree

src/dqlitewire/messages/responses.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -322,17 +322,23 @@ def encode_body(self) -> bytes:
322322
def decode_body(
323323
cls, data: bytes, schema: int = 0, max_rows: int = DEFAULT_MAX_ROWS
324324
) -> "RowsResponse":
325+
# Wrap in memoryview so per-iteration slices are O(1) rather
326+
# than O(remaining). Without this, a body with many small rows
327+
# triggers quadratic-time decode (issue 228): each
328+
# ``data[offset:]`` allocates a fresh ``bytes`` copy of the
329+
# tail. Memoryview slicing is a view, so slicing is free.
330+
view = memoryview(data)
325331
offset = 0
326332

327333
# Column count
328-
column_count = decode_uint64(data[offset:])
334+
column_count = decode_uint64(view[offset:])
329335
offset += 8
330336

331337
if column_count > _MAX_COLUMN_COUNT:
332338
raise DecodeError(f"Column count {column_count} exceeds maximum {_MAX_COLUMN_COUNT}")
333339

334340
# Bounds check: each column name is at least 8 bytes (null + padding)
335-
remaining = len(data) - offset
341+
remaining = len(view) - offset
336342
if column_count > remaining // 8:
337343
raise DecodeError(
338344
f"Column count {column_count} exceeds maximum possible in "
@@ -342,7 +348,7 @@ def decode_body(
342348
# Column names
343349
column_names: list[str] = []
344350
for _ in range(column_count):
345-
name, consumed = decode_text(data[offset:])
351+
name, consumed = decode_text(view[offset:])
346352
column_names.append(name)
347353
offset += consumed
348354

@@ -354,11 +360,11 @@ def decode_body(
354360
# Zero-column results cannot have row data (each row would be zero
355361
# bytes), so skip the row loop and consume the end marker directly.
356362
if column_count == 0:
357-
if offset + WORD_SIZE > len(data):
363+
if offset + WORD_SIZE > len(view):
358364
raise DecodeError(
359365
"RowsResponse body exhausted without end marker (zero-column result)"
360366
)
361-
marker_byte = data[offset]
367+
marker_byte = view[offset]
362368
if marker_byte == ROW_DONE_BYTE:
363369
has_more = False
364370
elif marker_byte == ROW_PART_BYTE:
@@ -375,9 +381,9 @@ def decode_body(
375381
has_more=has_more,
376382
)
377383

378-
while offset < len(data):
384+
while offset < len(view):
379385
# Read row header; markers are detected byte-by-byte inside
380-
result, consumed = decode_row_header(data[offset:], column_count)
386+
result, consumed = decode_row_header(view[offset:], column_count)
381387
offset += consumed
382388

383389
if result is RowMarker.DONE:
@@ -405,7 +411,7 @@ def decode_body(
405411
column_types = types
406412

407413
# Read row values
408-
values, consumed = decode_row_values(data[offset:], types)
414+
values, consumed = decode_row_values(view[offset:], types)
409415
rows.append(values)
410416
offset += consumed
411417

@@ -414,7 +420,7 @@ def decode_body(
414420

415421
raise DecodeError(
416422
f"RowsResponse body exhausted without end marker "
417-
f"(decoded {len(rows)} rows, consumed {offset} of {len(data)} bytes)"
423+
f"(decoded {len(rows)} rows, consumed {offset} of {len(view)} bytes)"
418424
)
419425

420426

@@ -464,30 +470,32 @@ def encode_body(self) -> bytes:
464470

465471
@classmethod
466472
def decode_body(cls, data: bytes, schema: int = 0) -> "FilesResponse":
473+
# Memoryview for O(1) slicing in the per-file loop (issue 228).
474+
view = memoryview(data)
467475
files: dict[str, bytes] = {}
468476
offset = 0
469-
count = decode_uint64(data[offset:])
477+
count = decode_uint64(view[offset:])
470478
offset += 8
471479
if count > _MAX_FILE_COUNT:
472480
raise DecodeError(f"File count {count} exceeds maximum {_MAX_FILE_COUNT}")
473481
# Bounds check: each file is at least 16 bytes (name + size)
474-
remaining = len(data) - offset
482+
remaining = len(view) - offset
475483
if count > remaining // 16:
476484
raise DecodeError(
477485
f"File count {count} exceeds maximum possible in "
478486
f"{remaining} bytes of remaining data"
479487
)
480488
for _ in range(count):
481-
name, consumed = decode_text(data[offset:])
489+
name, consumed = decode_text(view[offset:])
482490
offset += consumed
483-
size = decode_uint64(data[offset:])
491+
size = decode_uint64(view[offset:])
484492
offset += 8
485-
if offset + size > len(data):
493+
if offset + size > len(view):
486494
raise DecodeError(
487495
f"FilesResponse file content truncated: expected {size} bytes "
488-
f"at offset {offset}, but only {len(data) - offset} bytes available"
496+
f"at offset {offset}, but only {len(view) - offset} bytes available"
489497
)
490-
content = data[offset : offset + size]
498+
content = bytes(view[offset : offset + size])
491499
# No padding after content — matches Go's byte-by-byte read.
492500
offset += size
493501
files[name] = content
@@ -524,25 +532,27 @@ def encode_body(self) -> bytes:
524532

525533
@classmethod
526534
def decode_body(cls, data: bytes, schema: int = 0) -> "ServersResponse":
535+
# Memoryview for O(1) slicing in the per-node loop (issue 228).
536+
view = memoryview(data)
527537
nodes: list[NodeInfo] = []
528538
offset = 0
529-
count = decode_uint64(data[offset:])
539+
count = decode_uint64(view[offset:])
530540
offset += 8
531541
if count > _MAX_NODE_COUNT:
532542
raise DecodeError(f"Node count {count} exceeds maximum {_MAX_NODE_COUNT}")
533543
# Bounds check: each node is at least 24 bytes (id + address + role)
534-
remaining = len(data) - offset
544+
remaining = len(view) - offset
535545
if count > remaining // 24:
536546
raise DecodeError(
537547
f"Node count {count} exceeds maximum possible in "
538548
f"{remaining} bytes of remaining data"
539549
)
540550
for _ in range(count):
541-
node_id = decode_uint64(data[offset:])
551+
node_id = decode_uint64(view[offset:])
542552
offset += 8
543-
address, consumed = decode_text(data[offset:])
553+
address, consumed = decode_text(view[offset:])
544554
offset += consumed
545-
role = decode_uint64(data[offset:])
555+
role = decode_uint64(view[offset:])
546556
offset += 8
547557
nodes.append(NodeInfo(node_id, address, role))
548558
return cls(nodes)

src/dqlitewire/tuples.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,9 @@ def encode_row_header(types: Sequence[ValueType]) -> bytes:
203203
return bytes(header)
204204

205205

206-
def decode_row_header(data: bytes, column_count: int) -> tuple[list[ValueType] | RowMarker, int]:
206+
def decode_row_header(
207+
data: bytes | memoryview, column_count: int
208+
) -> tuple[list[ValueType] | RowMarker, int]:
207209
"""Decode row column type header.
208210
209211
Format: 4-bit type codes packed two per byte, padded to word boundary.
@@ -255,7 +257,9 @@ def encode_row_values(values: Sequence[Any], types: Sequence[ValueType]) -> byte
255257
return bytes(result)
256258

257259

258-
def decode_row_values(data: bytes, types: Sequence[ValueType]) -> tuple[list[Any], int]:
260+
def decode_row_values(
261+
data: bytes | memoryview, types: Sequence[ValueType]
262+
) -> tuple[list[Any], int]:
259263
"""Decode row values according to column types.
260264
261265
Returns (values, bytes_consumed).

src/dqlitewire/types.py

Lines changed: 74 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ def encode_uint64(value: int) -> bytes:
1919
return struct.pack("<Q", value)
2020

2121

22-
def decode_uint64(data: bytes) -> int:
23-
"""Decode an unsigned 64-bit integer (little-endian)."""
22+
def decode_uint64(data: bytes | memoryview) -> int:
23+
"""Decode an unsigned 64-bit integer (little-endian).
24+
25+
Accepts ``bytes`` or ``memoryview`` so hot-path body decoders
26+
(issue 228) can pass memoryview slices without copying.
27+
"""
2428
if len(data) < 8:
2529
raise DecodeError(f"Need 8 bytes for uint64, got {len(data)}")
2630
result: int = struct.unpack("<Q", data[:8])[0]
@@ -34,8 +38,11 @@ def encode_int64(value: int) -> bytes:
3438
return struct.pack("<q", value)
3539

3640

37-
def decode_int64(data: bytes) -> int:
38-
"""Decode a signed 64-bit integer (little-endian)."""
41+
def decode_int64(data: bytes | memoryview) -> int:
42+
"""Decode a signed 64-bit integer (little-endian).
43+
44+
Accepts ``bytes`` or ``memoryview`` (issue 228).
45+
"""
3946
if len(data) < 8:
4047
raise DecodeError(f"Need 8 bytes for int64, got {len(data)}")
4148
result: int = struct.unpack("<q", data[:8])[0]
@@ -49,8 +56,11 @@ def encode_uint32(value: int) -> bytes:
4956
return struct.pack("<I", value)
5057

5158

52-
def decode_uint32(data: bytes) -> int:
53-
"""Decode an unsigned 32-bit integer (little-endian)."""
59+
def decode_uint32(data: bytes | memoryview) -> int:
60+
"""Decode an unsigned 32-bit integer (little-endian).
61+
62+
Accepts ``bytes`` or ``memoryview`` (issue 228).
63+
"""
5464
if len(data) < 4:
5565
raise DecodeError(f"Need 4 bytes for uint32, got {len(data)}")
5666
result: int = struct.unpack("<I", data[:4])[0]
@@ -66,11 +76,12 @@ def encode_double(value: float) -> bytes:
6676
return struct.pack("<d", value)
6777

6878

69-
def decode_double(data: bytes) -> float:
79+
def decode_double(data: bytes | memoryview) -> float:
7080
"""Decode a 64-bit floating point number (little-endian).
7181
7282
All IEEE 754 values are accepted, including NaN and infinity,
73-
matching the Go reference implementation behavior.
83+
matching the Go reference implementation behavior. Accepts
84+
``bytes`` or ``memoryview`` (issue 228).
7485
"""
7586
if len(data) < 8:
7687
raise DecodeError(f"Need 8 bytes for double, got {len(data)}")
@@ -103,21 +114,60 @@ def encode_text(value: str) -> bytes:
103114
return encoded + (b"\x00" * padding)
104115

105116

106-
def decode_text(data: bytes) -> tuple[str, int]:
117+
_TEXT_SCAN_CHUNK = 4096
118+
119+
120+
def decode_text(data: bytes | memoryview) -> tuple[str, int]:
107121
"""Decode null-terminated UTF-8 text.
108122
109-
Returns the decoded string and the number of bytes consumed (including padding).
123+
Accepts either ``bytes`` or ``memoryview``. Returns the decoded
124+
string and the number of bytes consumed (including padding).
125+
126+
The decoder's hot body loops (RowsResponse, FilesResponse,
127+
ServersResponse) wrap the body in a ``memoryview`` so
128+
per-iteration slices are O(1) rather than O(remaining) — see
129+
issue 228. ``bytes`` inputs use zero-copy ``.index(b"\\x00")``.
130+
``memoryview`` inputs walk the buffer in fixed-size chunks so the
131+
per-chunk ``bytes(...)`` copy is bounded; arbitrarily long text
132+
values (e.g. multi-KiB SQL strings or TEXT column values) are
133+
supported because the scan simply visits more chunks. Per-call
134+
cost scales with the actual text length, not with the remaining
135+
body.
110136
"""
111-
# Find null terminator
112-
try:
113-
null_pos = data.index(b"\x00")
114-
except ValueError as e:
115-
raise DecodeError("Text not null-terminated") from e
137+
if isinstance(data, memoryview):
138+
# Memoryview has no ``.index(bytes)``. Scan in fixed chunks and
139+
# accumulate so we can decode the full text without re-copying
140+
# after the NUL is found.
141+
chunks: list[bytes] = []
142+
scanned = 0
143+
null_pos = -1
144+
data_len = len(data)
145+
while scanned < data_len:
146+
chunk_end = min(scanned + _TEXT_SCAN_CHUNK, data_len)
147+
chunk = bytes(data[scanned:chunk_end])
148+
local = chunk.find(b"\x00")
149+
if local >= 0:
150+
chunks.append(chunk[:local])
151+
null_pos = scanned + local
152+
break
153+
chunks.append(chunk)
154+
scanned = chunk_end
155+
if null_pos < 0:
156+
raise DecodeError("Text not null-terminated")
157+
try:
158+
text = b"".join(chunks).decode("utf-8")
159+
except UnicodeDecodeError as e:
160+
raise DecodeError(f"Invalid UTF-8 in text field: {e}") from e
161+
else:
162+
try:
163+
null_pos = data.index(b"\x00")
164+
except ValueError as e:
165+
raise DecodeError("Text not null-terminated") from e
166+
try:
167+
text = data[:null_pos].decode("utf-8")
168+
except UnicodeDecodeError as e:
169+
raise DecodeError(f"Invalid UTF-8 in text field: {e}") from e
116170

117-
try:
118-
text = data[:null_pos].decode("utf-8")
119-
except UnicodeDecodeError as e:
120-
raise DecodeError(f"Invalid UTF-8 in text field: {e}") from e
121171
# Calculate total size including padding
122172
total_size = null_pos + 1 + pad_to_word(null_pos + 1)
123173
if len(data) < total_size:
@@ -135,10 +185,11 @@ def encode_blob(value: bytes) -> bytes:
135185
return encode_uint64(length) + value + (b"\x00" * padding)
136186

137187

138-
def decode_blob(data: bytes) -> tuple[bytes, int]:
188+
def decode_blob(data: bytes | memoryview) -> tuple[bytes, int]:
139189
"""Decode a blob.
140190
141-
Returns the blob data and the number of bytes consumed.
191+
Accepts either ``bytes`` or ``memoryview``. Returns the blob data
192+
(always as ``bytes``) and the number of bytes consumed.
142193
"""
143194
if len(data) < 8:
144195
raise DecodeError("Not enough data for blob length")
@@ -149,7 +200,7 @@ def decode_blob(data: bytes) -> tuple[bytes, int]:
149200
if len(data) < total_size:
150201
raise DecodeError(f"Not enough data for blob: need {total_size}, got {len(data)}")
151202

152-
return data[8 : 8 + length], total_size
203+
return bytes(data[8 : 8 + length]), total_size
153204

154205

155206
def _format_datetime_iso8601(value: datetime.datetime) -> str:
@@ -276,7 +327,7 @@ def _parse_iso8601(text: str) -> datetime.datetime:
276327
raise DecodeError(f"Cannot parse ISO 8601 datetime: {text!r}")
277328

278329

279-
def decode_value(data: bytes, value_type: ValueType) -> tuple[Any, int]:
330+
def decode_value(data: bytes | memoryview, value_type: ValueType) -> tuple[Any, int]:
280331
"""Decode a value from wire format.
281332
282333
Returns (value, bytes_consumed).

0 commit comments

Comments
 (0)