Skip to content

Commit 7732f73

Browse files
feat(client): validate max_total_rows constructor argument
Add _validate_max_total_rows helper in protocol.py that rejects non-int, bool, and non-positive values. Apply in DqliteConnection, ConnectionPool, and DqliteProtocol constructors so typos fail fast with a clear message instead of surfacing later as obscure "Query exceeded max_total_rows cap (0)" errors at query time. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 05767d6 commit 7732f73

File tree

4 files changed

+106
-4
lines changed

4 files changed

+106
-4
lines changed

src/dqliteclient/connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
OperationalError,
1515
ProtocolError,
1616
)
17-
from dqliteclient.protocol import DqliteProtocol
17+
from dqliteclient.protocol import DqliteProtocol, _validate_max_total_rows
1818
from dqlitewire.exceptions import EncodeError as _WireEncodeError
1919

2020
# dqlite error codes that indicate a leader change (SQLite extended error codes)
@@ -93,7 +93,7 @@ def __init__(
9393
self._address = address
9494
self._database = database
9595
self._timeout = timeout
96-
self._max_total_rows = max_total_rows
96+
self._max_total_rows = _validate_max_total_rows(max_total_rows)
9797
self._protocol: DqliteProtocol | None = None
9898
self._db_id: int | None = None
9999
self._in_transaction = False

src/dqliteclient/pool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from dqliteclient.connection import DqliteConnection
1111
from dqliteclient.exceptions import DqliteConnectionError
1212
from dqliteclient.node_store import NodeStore
13+
from dqliteclient.protocol import _validate_max_total_rows
1314

1415

1516
def _socket_looks_dead(conn: DqliteConnection) -> bool:
@@ -98,7 +99,7 @@ def __init__(
9899
self._min_size = min_size
99100
self._max_size = max_size
100101
self._timeout = timeout
101-
self._max_total_rows = max_total_rows
102+
self._max_total_rows = _validate_max_total_rows(max_total_rows)
102103

103104
if cluster is not None:
104105
self._cluster = cluster

src/dqliteclient/protocol.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@
3232
_READ_CHUNK_SIZE = 4096
3333

3434

35+
def _validate_max_total_rows(value: int | None) -> int | None:
36+
"""Validate the ``max_total_rows`` constructor argument.
37+
38+
``None`` disables the cap. Otherwise the value must be a positive
39+
``int`` (``bool`` is rejected even though it's a subclass of int).
40+
"""
41+
if value is None:
42+
return None
43+
if isinstance(value, bool) or not isinstance(value, int):
44+
raise TypeError(f"max_total_rows must be int or None, got {type(value).__name__}")
45+
if value <= 0:
46+
raise ValueError(f"max_total_rows must be > 0 or None, got {value}")
47+
return value
48+
49+
3550
class DqliteProtocol:
3651
"""Low-level protocol handler for a single dqlite connection."""
3752

@@ -53,7 +68,7 @@ def __init__(
5368
# the per-operation deadline; without a cumulative cap, clients
5469
# could legitimately allocate hundreds of millions of rows over
5570
# the full deadline. None disables the cap.
56-
self._max_total_rows = max_total_rows
71+
self._max_total_rows = _validate_max_total_rows(max_total_rows)
5772

5873
async def handshake(self, client_id: int | None = None) -> int:
5974
"""Perform protocol handshake.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Validation of the ``max_total_rows`` constructor parameter."""
2+
3+
import asyncio
4+
5+
import pytest
6+
7+
from dqliteclient.connection import DqliteConnection
8+
from dqliteclient.pool import ConnectionPool
9+
from dqliteclient.protocol import DqliteProtocol, _validate_max_total_rows
10+
11+
12+
class TestValidator:
13+
def test_none_allowed(self) -> None:
14+
assert _validate_max_total_rows(None) is None
15+
16+
def test_positive_int_allowed(self) -> None:
17+
assert _validate_max_total_rows(1) == 1
18+
assert _validate_max_total_rows(10_000_000) == 10_000_000
19+
20+
def test_zero_rejected(self) -> None:
21+
with pytest.raises(ValueError, match="max_total_rows must be > 0"):
22+
_validate_max_total_rows(0)
23+
24+
def test_negative_rejected(self) -> None:
25+
with pytest.raises(ValueError, match="max_total_rows must be > 0"):
26+
_validate_max_total_rows(-1)
27+
28+
def test_float_rejected(self) -> None:
29+
with pytest.raises(TypeError, match="max_total_rows must be int or None"):
30+
_validate_max_total_rows(1.5) # type: ignore[arg-type]
31+
32+
def test_bool_rejected(self) -> None:
33+
# True is technically int, but PEP-489-style APIs rightly reject it.
34+
with pytest.raises(TypeError, match="max_total_rows must be int or None"):
35+
_validate_max_total_rows(True) # type: ignore[arg-type]
36+
37+
def test_string_rejected(self) -> None:
38+
with pytest.raises(TypeError, match="max_total_rows must be int or None"):
39+
_validate_max_total_rows("100") # type: ignore[arg-type]
40+
41+
42+
class TestConstructorValidation:
43+
def test_dqlite_connection_zero_rejected(self) -> None:
44+
with pytest.raises(ValueError):
45+
DqliteConnection("localhost:19001", max_total_rows=0)
46+
47+
def test_dqlite_connection_negative_rejected(self) -> None:
48+
with pytest.raises(ValueError):
49+
DqliteConnection("localhost:19001", max_total_rows=-5)
50+
51+
def test_dqlite_connection_bool_rejected(self) -> None:
52+
with pytest.raises(TypeError):
53+
DqliteConnection("localhost:19001", max_total_rows=True) # type: ignore[arg-type]
54+
55+
def test_dqlite_connection_none_allowed(self) -> None:
56+
conn = DqliteConnection("localhost:19001", max_total_rows=None)
57+
assert conn._max_total_rows is None
58+
59+
def test_pool_zero_rejected(self) -> None:
60+
with pytest.raises(ValueError):
61+
ConnectionPool(addresses=["localhost:19001"], max_total_rows=0)
62+
63+
def test_pool_negative_rejected(self) -> None:
64+
with pytest.raises(ValueError):
65+
ConnectionPool(addresses=["localhost:19001"], max_total_rows=-1)
66+
67+
@pytest.mark.asyncio
68+
async def test_protocol_zero_rejected(self) -> None:
69+
reader = asyncio.StreamReader()
70+
writer = _DummyWriter()
71+
with pytest.raises(ValueError):
72+
DqliteProtocol(reader, writer, max_total_rows=0) # type: ignore[arg-type]
73+
74+
75+
class _DummyWriter:
76+
def close(self) -> None:
77+
pass
78+
79+
async def wait_closed(self) -> None:
80+
pass
81+
82+
def write(self, data: bytes) -> None:
83+
pass
84+
85+
async def drain(self) -> None:
86+
pass

0 commit comments

Comments
 (0)