Skip to content

Commit 906fe91

Browse files
Cycles 16-17: exercise max_total_rows cap + allowlist_policy edges
- New test_max_total_rows_cap.py triggers the _drain_continuations cap directly via mocked continuation frames: raises when the cap is exceeded, accepts exactly at the boundary, None disables the cap. - allowlist_policy now takes Iterable[str] and materializes to frozenset once (safe for generators, dict_keys, etc.). New tests cover: empty allowlist rejects all redirects; iterator input works across repeated calls (proves one-shot materialization). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ce6a941 commit 906fe91

3 files changed

Lines changed: 112 additions & 3 deletions

File tree

src/dqliteclient/cluster.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import asyncio
44
import random
5-
from collections.abc import Callable
5+
from collections.abc import Callable, Iterable
66

77
from dqliteclient.connection import DqliteConnection, _parse_address
88
from dqliteclient.exceptions import (
@@ -191,15 +191,19 @@ async def update_nodes(self, nodes: list[NodeInfo]) -> None:
191191
await self._node_store.set_nodes(nodes)
192192

193193

194-
def allowlist_policy(addresses: list[str] | set[str]) -> RedirectPolicy:
194+
def allowlist_policy(addresses: Iterable[str]) -> RedirectPolicy:
195195
"""Build a redirect policy that accepts only the given addresses.
196196
197197
Useful for the common case: "only allow redirects to hosts I've
198198
explicitly seed-listed." Addresses are matched by exact string
199199
equality — callers that need CIDR / DNS / wildcard matching should
200200
supply their own callable.
201+
202+
Accepts any iterable (list, set, tuple, generator, dict_keys). The
203+
iterable is materialized into a frozen set once, so passing a
204+
generator is safe — the returned closure doesn't re-iterate.
201205
"""
202-
allowed = set(addresses)
206+
allowed = frozenset(addresses)
203207

204208
def policy(addr: str) -> bool:
205209
return addr in allowed

tests/test_max_total_rows_cap.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""max_total_rows cap actually fires in the continuation-drain loop.
2+
3+
Cycle 9 wired the cap through the layers; cycle 16 adds the missing
4+
test that exercises the enforcement itself. Uses a mocked protocol
5+
response stream so we don't need a cluster to deliver millions of rows.
6+
"""
7+
8+
import asyncio
9+
from unittest.mock import AsyncMock, MagicMock
10+
11+
import pytest
12+
13+
from dqliteclient.exceptions import ProtocolError
14+
from dqliteclient.protocol import DqliteProtocol
15+
16+
17+
def _make_rows_response(rows: list[list[object]], has_more: bool) -> MagicMock:
18+
r = MagicMock(name="RowsResponse")
19+
r.rows = rows
20+
r.has_more = has_more
21+
r.column_names = ["v"]
22+
r.column_types = [1] # INTEGER
23+
return r
24+
25+
26+
class TestMaxTotalRowsEnforcement:
27+
def test_exceeding_cap_raises_protocol_error(self) -> None:
28+
"""A continuation frame that pushes us past max_total_rows raises."""
29+
reader = MagicMock()
30+
writer = MagicMock()
31+
p = DqliteProtocol(reader, writer, timeout=5.0, max_total_rows=5)
32+
33+
initial = _make_rows_response([[1], [2], [3]], has_more=True)
34+
# This next frame would bring total to 6, exceeding the cap of 5.
35+
over_cap = _make_rows_response([[4], [5], [6]], has_more=False)
36+
37+
p._read_continuation = AsyncMock(return_value=over_cap) # type: ignore[method-assign]
38+
39+
async def run() -> None:
40+
await p._drain_continuations(initial, deadline=999999.0)
41+
42+
with pytest.raises(ProtocolError, match="max_total_rows"):
43+
asyncio.run(run())
44+
45+
def test_exactly_at_cap_does_not_raise(self) -> None:
46+
"""Hitting the cap exactly is fine; only exceeding raises."""
47+
reader = MagicMock()
48+
writer = MagicMock()
49+
p = DqliteProtocol(reader, writer, timeout=5.0, max_total_rows=5)
50+
51+
initial = _make_rows_response([[1], [2], [3]], has_more=True)
52+
at_cap = _make_rows_response([[4], [5]], has_more=False) # total: 5
53+
54+
p._read_continuation = AsyncMock(return_value=at_cap) # type: ignore[method-assign]
55+
56+
rows = asyncio.run(p._drain_continuations(initial, deadline=999999.0))
57+
assert len(rows) == 5
58+
59+
def test_none_disables_cap(self) -> None:
60+
"""max_total_rows=None means the cap never fires."""
61+
reader = MagicMock()
62+
writer = MagicMock()
63+
p = DqliteProtocol(reader, writer, timeout=5.0, max_total_rows=None)
64+
65+
initial = _make_rows_response([[i] for i in range(100)], has_more=True)
66+
big = _make_rows_response([[i] for i in range(100, 10_000)], has_more=False)
67+
68+
p._read_continuation = AsyncMock(return_value=big) # type: ignore[method-assign]
69+
70+
rows = asyncio.run(p._drain_continuations(initial, deadline=999999.0))
71+
assert len(rows) == 10_000

tests/test_redirect_policy.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,40 @@ def test_no_policy_means_any_redirect_accepted(self) -> None:
5252
result = asyncio.run(cc.find_leader())
5353
assert result == "anywhere.invalid:9001"
5454

55+
def test_empty_allowlist_rejects_all_redirects(self) -> None:
56+
"""Empty allowlist means every redirect fails; self-leader still
57+
works because that path bypasses the policy entirely (it's not a
58+
real redirect)."""
59+
store = MemoryNodeStore(["10.0.0.1:9001"])
60+
cc = ClusterClient(
61+
store,
62+
timeout=5.0,
63+
redirect_policy=allowlist_policy([]),
64+
)
65+
with (
66+
patch.object(cc, "_query_leader", new=AsyncMock(return_value="other:9001")),
67+
pytest.raises(ClusterError, match="rejected"),
68+
):
69+
asyncio.run(cc.find_leader())
70+
71+
def test_allowlist_accepts_iterator_input(self) -> None:
72+
"""The helper accepts any iterable; internally it materializes once
73+
into a set, so a generator is safe (no iterator-exhaustion trap on
74+
repeated calls)."""
75+
store = MemoryNodeStore(["10.0.0.1:9001"])
76+
cc = ClusterClient(
77+
store,
78+
timeout=5.0,
79+
redirect_policy=allowlist_policy(x for x in ["10.0.0.1:9001", "10.0.0.2:9001"]),
80+
)
81+
# First call — would have drained the generator already.
82+
with patch.object(cc, "_query_leader", new=AsyncMock(return_value="10.0.0.2:9001")):
83+
assert asyncio.run(cc.find_leader()) == "10.0.0.2:9001"
84+
# Second call still honors the allowlist (proves the set was
85+
# materialized up-front, not re-iterated).
86+
with patch.object(cc, "_query_leader", new=AsyncMock(return_value="10.0.0.2:9001")):
87+
assert asyncio.run(cc.find_leader()) == "10.0.0.2:9001"
88+
5589
def test_self_leader_bypasses_policy(self) -> None:
5690
"""If the queried node is the leader (returns its own address),
5791
the redirect policy doesn't apply — the address is already in the

0 commit comments

Comments
 (0)