Skip to content

Commit 4d3e1dd

Browse files
Wrap body in memoryview to eliminate quadratic decode
Every body decoder in responses.py (RowsResponse, FilesResponse, ServersResponse) and tuples.py walked the body with an integer offset but passed `data[offset:]` into helpers. Each slice of a `bytes` object copies the remaining tail, producing O(N²) cumulative memcpy on messages with many small rows. Measured quadratic scaling: 50k rows took 2s, 100k took 11s, 200k took 48s. Wrap the body in a `memoryview` at the top of each `decode_body` so per-iteration slices are O(1) views. Widen the primitive decoder signatures (decode_uint32/uint64/int64/double/text/blob, decode_row_header, decode_row_values) to accept `bytes | memoryview`. Use `struct.unpack` (already memoryview-compatible) and a chunked accumulating scan for `decode_text`'s NUL search — memoryview has no `.index(bytes)` method, so the scan reads 4 KiB chunks at a time via `bytes(data[a:b]).find(b"\x00")` and accumulates into a list. The chunked accumulation supports arbitrarily long TEXT values (SQLite TEXT columns routinely exceed 4 KiB for JSON / prose) while keeping per-chunk copy cost bounded. Post-fix: 10x input takes ~10x time, not ~100x. Regression tests check both the linear scaling ratio and long-text roundtrip for memoryview inputs. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 567ef93 commit 4d3e1dd

4 files changed

Lines changed: 252 additions & 46 deletions

File tree

src/dqlitewire/messages/responses.py

Lines changed: 31 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -322,17 +322,23 @@ def encode_body(self) -> bytes:
322322
def decode_body(
323323
cls, data: bytes, schema: int = 0, max_rows: int = DEFAULT_MAX_ROWS
324324
) -> "RowsResponse":
325+
# Wrap in memoryview so per-iteration slices are O(1) rather
326+
# than O(remaining). Without this, a body with many small rows
327+
# triggers quadratic-time decode (issue 228): each
328+
# ``data[offset:]`` allocates a fresh ``bytes`` copy of the
329+
# tail. Memoryview slicing is a view, so slicing is free.
330+
view = memoryview(data)
325331
offset = 0
326332

327333
# Column count
328-
column_count = decode_uint64(data[offset:])
334+
column_count = decode_uint64(view[offset:])
329335
offset += 8
330336

331337
if column_count > _MAX_COLUMN_COUNT:
332338
raise DecodeError(f"Column count {column_count} exceeds maximum {_MAX_COLUMN_COUNT}")
333339

334340
# Bounds check: each column name is at least 8 bytes (null + padding)
335-
remaining = len(data) - offset
341+
remaining = len(view) - offset
336342
if column_count > remaining // 8:
337343
raise DecodeError(
338344
f"Column count {column_count} exceeds maximum possible in "
@@ -342,7 +348,7 @@ def decode_body(
342348
# Column names
343349
column_names: list[str] = []
344350
for _ in range(column_count):
345-
name, consumed = decode_text(data[offset:])
351+
name, consumed = decode_text(view[offset:])
346352
column_names.append(name)
347353
offset += consumed
348354

@@ -354,11 +360,11 @@ def decode_body(
354360
# Zero-column results cannot have row data (each row would be zero
355361
# bytes), so skip the row loop and consume the end marker directly.
356362
if column_count == 0:
357-
if offset + WORD_SIZE > len(data):
363+
if offset + WORD_SIZE > len(view):
358364
raise DecodeError(
359365
"RowsResponse body exhausted without end marker (zero-column result)"
360366
)
361-
marker_byte = data[offset]
367+
marker_byte = view[offset]
362368
if marker_byte == ROW_DONE_BYTE:
363369
has_more = False
364370
elif marker_byte == ROW_PART_BYTE:
@@ -375,9 +381,9 @@ def decode_body(
375381
has_more=has_more,
376382
)
377383

378-
while offset < len(data):
384+
while offset < len(view):
379385
# Read row header; markers are detected byte-by-byte inside
380-
result, consumed = decode_row_header(data[offset:], column_count)
386+
result, consumed = decode_row_header(view[offset:], column_count)
381387
offset += consumed
382388

383389
if result is RowMarker.DONE:
@@ -405,7 +411,7 @@ def decode_body(
405411
column_types = types
406412

407413
# Read row values
408-
values, consumed = decode_row_values(data[offset:], types)
414+
values, consumed = decode_row_values(view[offset:], types)
409415
rows.append(values)
410416
offset += consumed
411417

@@ -414,7 +420,7 @@ def decode_body(
414420

415421
raise DecodeError(
416422
f"RowsResponse body exhausted without end marker "
417-
f"(decoded {len(rows)} rows, consumed {offset} of {len(data)} bytes)"
423+
f"(decoded {len(rows)} rows, consumed {offset} of {len(view)} bytes)"
418424
)
419425

420426

@@ -464,30 +470,32 @@ def encode_body(self) -> bytes:
464470

465471
@classmethod
466472
def decode_body(cls, data: bytes, schema: int = 0) -> "FilesResponse":
473+
# Memoryview for O(1) slicing in the per-file loop (issue 228).
474+
view = memoryview(data)
467475
files: dict[str, bytes] = {}
468476
offset = 0
469-
count = decode_uint64(data[offset:])
477+
count = decode_uint64(view[offset:])
470478
offset += 8
471479
if count > _MAX_FILE_COUNT:
472480
raise DecodeError(f"File count {count} exceeds maximum {_MAX_FILE_COUNT}")
473481
# Bounds check: each file is at least 16 bytes (name + size)
474-
remaining = len(data) - offset
482+
remaining = len(view) - offset
475483
if count > remaining // 16:
476484
raise DecodeError(
477485
f"File count {count} exceeds maximum possible in "
478486
f"{remaining} bytes of remaining data"
479487
)
480488
for _ in range(count):
481-
name, consumed = decode_text(data[offset:])
489+
name, consumed = decode_text(view[offset:])
482490
offset += consumed
483-
size = decode_uint64(data[offset:])
491+
size = decode_uint64(view[offset:])
484492
offset += 8
485-
if offset + size > len(data):
493+
if offset + size > len(view):
486494
raise DecodeError(
487495
f"FilesResponse file content truncated: expected {size} bytes "
488-
f"at offset {offset}, but only {len(data) - offset} bytes available"
496+
f"at offset {offset}, but only {len(view) - offset} bytes available"
489497
)
490-
content = data[offset : offset + size]
498+
content = bytes(view[offset : offset + size])
491499
# No padding after content — matches Go's byte-by-byte read.
492500
offset += size
493501
files[name] = content
@@ -524,25 +532,27 @@ def encode_body(self) -> bytes:
524532

525533
@classmethod
526534
def decode_body(cls, data: bytes, schema: int = 0) -> "ServersResponse":
535+
# Memoryview for O(1) slicing in the per-node loop (issue 228).
536+
view = memoryview(data)
527537
nodes: list[NodeInfo] = []
528538
offset = 0
529-
count = decode_uint64(data[offset:])
539+
count = decode_uint64(view[offset:])
530540
offset += 8
531541
if count > _MAX_NODE_COUNT:
532542
raise DecodeError(f"Node count {count} exceeds maximum {_MAX_NODE_COUNT}")
533543
# Bounds check: each node is at least 24 bytes (id + address + role)
534-
remaining = len(data) - offset
544+
remaining = len(view) - offset
535545
if count > remaining // 24:
536546
raise DecodeError(
537547
f"Node count {count} exceeds maximum possible in "
538548
f"{remaining} bytes of remaining data"
539549
)
540550
for _ in range(count):
541-
node_id = decode_uint64(data[offset:])
551+
node_id = decode_uint64(view[offset:])
542552
offset += 8
543-
address, consumed = decode_text(data[offset:])
553+
address, consumed = decode_text(view[offset:])
544554
offset += consumed
545-
role = decode_uint64(data[offset:])
555+
role = decode_uint64(view[offset:])
546556
offset += 8
547557
nodes.append(NodeInfo(node_id, address, role))
548558
return cls(nodes)

src/dqlitewire/tuples.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,9 @@ def encode_row_header(types: Sequence[ValueType]) -> bytes:
203203
return bytes(header)
204204

205205

206-
def decode_row_header(data: bytes, column_count: int) -> tuple[list[ValueType] | RowMarker, int]:
206+
def decode_row_header(
207+
data: bytes | memoryview, column_count: int
208+
) -> tuple[list[ValueType] | RowMarker, int]:
207209
"""Decode row column type header.
208210
209211
Format: 4-bit type codes packed two per byte, padded to word boundary.
@@ -255,7 +257,9 @@ def encode_row_values(values: Sequence[Any], types: Sequence[ValueType]) -> byte
255257
return bytes(result)
256258

257259

258-
def decode_row_values(data: bytes, types: Sequence[ValueType]) -> tuple[list[Any], int]:
260+
def decode_row_values(
261+
data: bytes | memoryview, types: Sequence[ValueType]
262+
) -> tuple[list[Any], int]:
259263
"""Decode row values according to column types.
260264
261265
Returns (values, bytes_consumed).

src/dqlitewire/types.py

Lines changed: 74 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,12 @@ def encode_uint64(value: int) -> bytes:
1919
return struct.pack("<Q", value)
2020

2121

22-
def decode_uint64(data: bytes) -> int:
23-
"""Decode an unsigned 64-bit integer (little-endian)."""
22+
def decode_uint64(data: bytes | memoryview) -> int:
23+
"""Decode an unsigned 64-bit integer (little-endian).
24+
25+
Accepts ``bytes`` or ``memoryview`` so hot-path body decoders
26+
(issue 228) can pass memoryview slices without copying.
27+
"""
2428
if len(data) < 8:
2529
raise DecodeError(f"Need 8 bytes for uint64, got {len(data)}")
2630
result: int = struct.unpack("<Q", data[:8])[0]
@@ -34,8 +38,11 @@ def encode_int64(value: int) -> bytes:
3438
return struct.pack("<q", value)
3539

3640

37-
def decode_int64(data: bytes) -> int:
38-
"""Decode a signed 64-bit integer (little-endian)."""
41+
def decode_int64(data: bytes | memoryview) -> int:
42+
"""Decode a signed 64-bit integer (little-endian).
43+
44+
Accepts ``bytes`` or ``memoryview`` (issue 228).
45+
"""
3946
if len(data) < 8:
4047
raise DecodeError(f"Need 8 bytes for int64, got {len(data)}")
4148
result: int = struct.unpack("<q", data[:8])[0]
@@ -49,8 +56,11 @@ def encode_uint32(value: int) -> bytes:
4956
return struct.pack("<I", value)
5057

5158

52-
def decode_uint32(data: bytes) -> int:
53-
"""Decode an unsigned 32-bit integer (little-endian)."""
59+
def decode_uint32(data: bytes | memoryview) -> int:
60+
"""Decode an unsigned 32-bit integer (little-endian).
61+
62+
Accepts ``bytes`` or ``memoryview`` (issue 228).
63+
"""
5464
if len(data) < 4:
5565
raise DecodeError(f"Need 4 bytes for uint32, got {len(data)}")
5666
result: int = struct.unpack("<I", data[:4])[0]
@@ -66,11 +76,12 @@ def encode_double(value: float) -> bytes:
6676
return struct.pack("<d", value)
6777

6878

69-
def decode_double(data: bytes) -> float:
79+
def decode_double(data: bytes | memoryview) -> float:
7080
"""Decode a 64-bit floating point number (little-endian).
7181
7282
All IEEE 754 values are accepted, including NaN and infinity,
73-
matching the Go reference implementation behavior.
83+
matching the Go reference implementation behavior. Accepts
84+
``bytes`` or ``memoryview`` (issue 228).
7485
"""
7586
if len(data) < 8:
7687
raise DecodeError(f"Need 8 bytes for double, got {len(data)}")
@@ -103,21 +114,60 @@ def encode_text(value: str) -> bytes:
103114
return encoded + (b"\x00" * padding)
104115

105116

106-
def decode_text(data: bytes) -> tuple[str, int]:
117+
_TEXT_SCAN_CHUNK = 4096
118+
119+
120+
def decode_text(data: bytes | memoryview) -> tuple[str, int]:
107121
"""Decode null-terminated UTF-8 text.
108122
109-
Returns the decoded string and the number of bytes consumed (including padding).
123+
Accepts either ``bytes`` or ``memoryview``. Returns the decoded
124+
string and the number of bytes consumed (including padding).
125+
126+
The decoder's hot body loops (RowsResponse, FilesResponse,
127+
ServersResponse) wrap the body in a ``memoryview`` so
128+
per-iteration slices are O(1) rather than O(remaining) — see
129+
issue 228. ``bytes`` inputs use zero-copy ``.index(b"\\x00")``.
130+
``memoryview`` inputs walk the buffer in fixed-size chunks so the
131+
per-chunk ``bytes(...)`` copy is bounded; arbitrarily long text
132+
values (e.g. multi-KiB SQL strings or TEXT column values) are
133+
supported because the scan simply visits more chunks. Per-call
134+
cost scales with the actual text length, not with the remaining
135+
body.
110136
"""
111-
# Find null terminator
112-
try:
113-
null_pos = data.index(b"\x00")
114-
except ValueError as e:
115-
raise DecodeError("Text not null-terminated") from e
137+
if isinstance(data, memoryview):
138+
# Memoryview has no ``.index(bytes)``. Scan in fixed chunks and
139+
# accumulate so we can decode the full text without re-copying
140+
# after the NUL is found.
141+
chunks: list[bytes] = []
142+
scanned = 0
143+
null_pos = -1
144+
data_len = len(data)
145+
while scanned < data_len:
146+
chunk_end = min(scanned + _TEXT_SCAN_CHUNK, data_len)
147+
chunk = bytes(data[scanned:chunk_end])
148+
local = chunk.find(b"\x00")
149+
if local >= 0:
150+
chunks.append(chunk[:local])
151+
null_pos = scanned + local
152+
break
153+
chunks.append(chunk)
154+
scanned = chunk_end
155+
if null_pos < 0:
156+
raise DecodeError("Text not null-terminated")
157+
try:
158+
text = b"".join(chunks).decode("utf-8")
159+
except UnicodeDecodeError as e:
160+
raise DecodeError(f"Invalid UTF-8 in text field: {e}") from e
161+
else:
162+
try:
163+
null_pos = data.index(b"\x00")
164+
except ValueError as e:
165+
raise DecodeError("Text not null-terminated") from e
166+
try:
167+
text = data[:null_pos].decode("utf-8")
168+
except UnicodeDecodeError as e:
169+
raise DecodeError(f"Invalid UTF-8 in text field: {e}") from e
116170

117-
try:
118-
text = data[:null_pos].decode("utf-8")
119-
except UnicodeDecodeError as e:
120-
raise DecodeError(f"Invalid UTF-8 in text field: {e}") from e
121171
# Calculate total size including padding
122172
total_size = null_pos + 1 + pad_to_word(null_pos + 1)
123173
if len(data) < total_size:
@@ -135,10 +185,11 @@ def encode_blob(value: bytes) -> bytes:
135185
return encode_uint64(length) + value + (b"\x00" * padding)
136186

137187

138-
def decode_blob(data: bytes) -> tuple[bytes, int]:
188+
def decode_blob(data: bytes | memoryview) -> tuple[bytes, int]:
139189
"""Decode a blob.
140190
141-
Returns the blob data and the number of bytes consumed.
191+
Accepts either ``bytes`` or ``memoryview``. Returns the blob data
192+
(always as ``bytes``) and the number of bytes consumed.
142193
"""
143194
if len(data) < 8:
144195
raise DecodeError("Not enough data for blob length")
@@ -149,7 +200,7 @@ def decode_blob(data: bytes) -> tuple[bytes, int]:
149200
if len(data) < total_size:
150201
raise DecodeError(f"Not enough data for blob: need {total_size}, got {len(data)}")
151202

152-
return data[8 : 8 + length], total_size
203+
return bytes(data[8 : 8 + length]), total_size
153204

154205

155206
def _format_datetime_iso8601(value: datetime.datetime) -> str:
@@ -276,7 +327,7 @@ def _parse_iso8601(text: str) -> datetime.datetime:
276327
raise DecodeError(f"Cannot parse ISO 8601 datetime: {text!r}")
277328

278329

279-
def decode_value(data: bytes, value_type: ValueType) -> tuple[Any, int]:
330+
def decode_value(data: bytes | memoryview, value_type: ValueType) -> tuple[Any, int]:
280331
"""Decode a value from wire format.
281332
282333
Returns (value, bytes_consumed).

0 commit comments

Comments (0)