Skip to content

Commit 41a9a84

Browse files
fix: address cycle-3 review — plumb caps and strengthen tests
Post-review follow-up (correctness + security + test-quality agents ran on the cycle-3 commits and flagged the items below). Critical — tests tried real connections to localhost:19001 - test_connect_forwards_max_total_rows refactored to inspect the Connection attribute without going through a connect flow. Uses a non-existent port (19999) because Connection.__init__ is pure state machine; nothing touches the network. - test_sync_cursor_execute_after_connection_close now uses a _ClosedConn fake whose _run_sync raises InterfaceError, mirroring the async variant's pattern. No cluster dependency. Medium — test strength - rownumber: add increment/fetchmany/fetchall coverage so the property's actual claim (advance with each fetched row) is pinned, not just the None-without-result-set path. Critical — plumbing gap - max_continuation_frames and trust_server_heartbeat now propagate through dqlitedbapi.Connection, dqlitedbapi.aio.Connection, and the module-level dqlitedbapi.connect(). Callers who go through the dbapi (the 99% path) can now tune the DoS governors end-to-end. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 66a2407 commit 41a9a84

File tree

5 files changed

+139
-23
lines changed

5 files changed

+139
-23
lines changed

src/dqlitedbapi/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def connect(
9595
database: str = "default",
9696
timeout: float = 10.0,
9797
max_total_rows: int | None = 10_000_000,
98+
max_continuation_frames: int | None = 100_000,
99+
trust_server_heartbeat: bool = False,
98100
) -> Connection:
99101
"""Connect to a dqlite database.
100102
@@ -108,6 +110,10 @@ def connect(
108110
max_total_rows: Cumulative row cap across continuation frames
109111
for a single query. Forwarded to the underlying
110112
:class:`Connection` (ISSUE-111). ``None`` disables the cap.
113+
max_continuation_frames: Per-query continuation-frame cap
114+
(ISSUE-98). Forwarded to the underlying :class:`Connection`.
115+
trust_server_heartbeat: Let the server-advertised heartbeat
116+
widen the per-read deadline (ISSUE-101). Default False.
111117
112118
Returns:
113119
A Connection object
@@ -121,4 +127,6 @@ def connect(
121127
database=database,
122128
timeout=timeout,
123129
max_total_rows=max_total_rows,
130+
max_continuation_frames=max_continuation_frames,
131+
trust_server_heartbeat=trust_server_heartbeat,
124132
)

src/dqlitedbapi/aio/connection.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import dqliteclient.exceptions as _client_exc
99
from dqliteclient import DqliteConnection
10-
from dqliteclient.protocol import _validate_max_total_rows
10+
from dqliteclient.protocol import _validate_max_total_rows, _validate_positive_int_or_none
1111
from dqlitedbapi.aio.cursor import AsyncCursor
1212
from dqlitedbapi.connection import _is_no_transaction_error
1313
from dqlitedbapi.exceptions import InterfaceError, OperationalError, ProgrammingError
@@ -23,6 +23,8 @@ def __init__(
2323
database: str = "default",
2424
timeout: float = 10.0,
2525
max_total_rows: int | None = 10_000_000,
26+
max_continuation_frames: int | None = 100_000,
27+
trust_server_heartbeat: bool = False,
2628
) -> None:
2729
"""Initialize connection (does not connect yet).
2830
@@ -33,13 +35,21 @@ def __init__(
3335
max_total_rows: Cumulative row cap across continuation
3436
frames. Forwarded to the underlying DqliteConnection;
3537
``None`` disables the cap.
38+
max_continuation_frames: Per-query continuation-frame cap
39+
(ISSUE-98). Forwarded to the underlying DqliteConnection.
40+
trust_server_heartbeat: When True, let the server-advertised
41+
heartbeat widen the per-read deadline (ISSUE-101).
3642
"""
3743
if not math.isfinite(timeout) or timeout <= 0:
3844
raise ProgrammingError(f"timeout must be a positive finite number, got {timeout}")
3945
self._address = address
4046
self._database = database
4147
self._timeout = timeout
4248
self._max_total_rows = _validate_max_total_rows(max_total_rows)
49+
self._max_continuation_frames = _validate_positive_int_or_none(
50+
max_continuation_frames, "max_continuation_frames"
51+
)
52+
self._trust_server_heartbeat = trust_server_heartbeat
4353
self._async_conn: DqliteConnection | None = None
4454
self._closed = False
4555
# asyncio primitives MUST be created inside the loop they will
@@ -79,6 +89,8 @@ async def _ensure_connection(self) -> DqliteConnection:
7989
database=self._database,
8090
timeout=self._timeout,
8191
max_total_rows=self._max_total_rows,
92+
max_continuation_frames=self._max_continuation_frames,
93+
trust_server_heartbeat=self._trust_server_heartbeat,
8294
)
8395
try:
8496
await conn.connect()

src/dqlitedbapi/connection.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import dqliteclient.exceptions as _client_exc
1212
from dqliteclient import DqliteConnection
13-
from dqliteclient.protocol import _validate_max_total_rows
13+
from dqliteclient.protocol import _validate_max_total_rows, _validate_positive_int_or_none
1414
from dqlitedbapi.cursor import Cursor
1515
from dqlitedbapi.exceptions import InterfaceError, OperationalError, ProgrammingError
1616

@@ -94,6 +94,8 @@ def __init__(
9494
database: str = "default",
9595
timeout: float = 10.0,
9696
max_total_rows: int | None = 10_000_000,
97+
max_continuation_frames: int | None = 100_000,
98+
trust_server_heartbeat: bool = False,
9799
) -> None:
98100
"""Initialize connection (does not connect yet).
99101
@@ -107,13 +109,25 @@ def __init__(
107109
max_total_rows: Cumulative row cap across continuation
108110
frames for a single query. Forwarded to the underlying
109111
:class:`DqliteConnection`. ``None`` disables the cap.
112+
max_continuation_frames: Per-query continuation-frame cap
113+
(ISSUE-98). Bounds Python-side decode work a hostile
114+
server can inflict by drip-feeding 1-row frames.
115+
Forwarded to the underlying :class:`DqliteConnection`.
116+
trust_server_heartbeat: When True, widen the per-read
117+
deadline to the server-advertised heartbeat (subject to
118+
a 300 s hard cap). Default False so the configured
119+
``timeout`` is authoritative (ISSUE-101).
110120
"""
111121
if not math.isfinite(timeout) or timeout <= 0:
112122
raise ProgrammingError(f"timeout must be a positive finite number, got {timeout}")
113123
self._address = address
114124
self._database = database
115125
self._timeout = timeout
116126
self._max_total_rows = _validate_max_total_rows(max_total_rows)
127+
self._max_continuation_frames = _validate_positive_int_or_none(
128+
max_continuation_frames, "max_continuation_frames"
129+
)
130+
self._trust_server_heartbeat = trust_server_heartbeat
117131
self._async_conn: DqliteConnection | None = None
118132
self._closed = False
119133
self._loop: asyncio.AbstractEventLoop | None = None
@@ -228,6 +242,8 @@ async def _get_async_connection(self) -> DqliteConnection:
228242
database=self._database,
229243
timeout=self._timeout,
230244
max_total_rows=self._max_total_rows,
245+
max_continuation_frames=self._max_continuation_frames,
246+
trust_server_heartbeat=self._trust_server_heartbeat,
231247
)
232248
try:
233249
await conn.connect()

tests/test_cycle3_hardening.py

Lines changed: 85 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,65 @@ class _FakeAsyncConn:
4343
cursor = AsyncCursor(_FakeAsyncConn()) # type: ignore[arg-type]
4444
assert cursor.rownumber is None
4545

46+
def test_rownumber_increments_with_fetchone(self) -> None:
47+
"""After each fetchone(), rownumber points at the next row."""
48+
from dqlitedbapi.cursor import Cursor
49+
50+
class _FakeConn:
51+
def _check_thread(self) -> None: ...
52+
53+
cursor = Cursor(_FakeConn()) # type: ignore[arg-type]
54+
cursor._description = [("x", None, None, None, None, None, None)]
55+
cursor._rows = [(1,), (2,), (3,)]
56+
57+
assert cursor.rownumber == 0 # before any fetch, cursor points at row 0
58+
59+
assert cursor.fetchone() == (1,)
60+
assert cursor.rownumber == 1
61+
62+
assert cursor.fetchone() == (2,)
63+
assert cursor.rownumber == 2
64+
65+
assert cursor.fetchone() == (3,)
66+
assert cursor.rownumber == 3 # past the end
67+
68+
# Further fetches return None and do not advance rownumber past len
69+
assert cursor.fetchone() is None
70+
assert cursor.rownumber == 3
71+
72+
def test_rownumber_after_fetchall(self) -> None:
73+
"""fetchall advances rownumber to end of result set."""
74+
from dqlitedbapi.cursor import Cursor
75+
76+
class _FakeConn:
77+
def _check_thread(self) -> None: ...
78+
79+
cursor = Cursor(_FakeConn()) # type: ignore[arg-type]
80+
cursor._description = [("x", None, None, None, None, None, None)]
81+
cursor._rows = [(i,) for i in range(5)]
82+
assert cursor.rownumber == 0
83+
84+
rows = cursor.fetchall()
85+
assert len(rows) == 5
86+
assert cursor.rownumber == 5
87+
88+
def test_rownumber_after_fetchmany(self) -> None:
89+
"""fetchmany advances rownumber by the number of rows fetched."""
90+
from dqlitedbapi.cursor import Cursor
91+
92+
class _FakeConn:
93+
def _check_thread(self) -> None: ...
94+
95+
cursor = Cursor(_FakeConn()) # type: ignore[arg-type]
96+
cursor._description = [("x", None, None, None, None, None, None)]
97+
cursor._rows = [(i,) for i in range(10)]
98+
99+
cursor.fetchmany(3)
100+
assert cursor.rownumber == 3
101+
102+
cursor.fetchmany(4)
103+
assert cursor.rownumber == 7
104+
46105

47106
class TestFetchmanyNegativeSize:
48107
def test_sync_fetchmany_rejects_negative(self) -> None:
@@ -221,27 +280,36 @@ def test_constraint_violation_is_not_silenced(self) -> None:
221280

222281

223282
class TestConnectForwardsMaxTotalRows:
224-
"""ISSUE-111: module-level connect() forwards max_total_rows."""
225-
226-
def test_connect_forwards_max_total_rows(self) -> None:
227-
from dqlitedbapi import connect
228-
from dqlitedbapi.connection import Connection
283+
"""ISSUE-111: module-level connect() forwards max_total_rows.
229284
230-
conn = connect("localhost:19001", max_total_rows=500)
231-
try:
232-
assert isinstance(conn, Connection)
233-
assert conn._max_total_rows == 500
234-
finally:
235-
conn.close()
285+
These tests only verify parameter plumbing — they do not open a
286+
socket. Connection.__init__ is pure state machine; no cluster
287+
needed. The previous implementation called conn.close() which
288+
required a running cluster for the event-loop thread to wind down
289+
cleanly.
290+
"""
236291

237-
def test_connect_forwards_none_for_max_total_rows(self) -> None:
292+
@pytest.mark.parametrize(
293+
"max_total_rows,expected",
294+
[(500, 500), (None, None), (10_000, 10_000)],
295+
)
296+
def test_connect_forwards_max_total_rows(
297+
self, max_total_rows: int | None, expected: int | None
298+
) -> None:
238299
from dqlitedbapi import connect
300+
from dqlitedbapi.connection import Connection
239301

240-
conn = connect("localhost:19001", max_total_rows=None)
241-
try:
242-
assert conn._max_total_rows is None
243-
finally:
244-
conn.close()
302+
# connect() does NOT actually connect to the server — it
303+
# instantiates a Connection with the given address and defers
304+
# the real TCP until first use. Inspect the attribute and then
305+
# skip close() because close() on a never-connected connection
306+
# is a silent no-op (no loop thread was started).
307+
conn = connect("localhost:19999", max_total_rows=max_total_rows)
308+
assert isinstance(conn, Connection)
309+
assert conn._max_total_rows == expected
310+
# conn.close() on an unused connection is a no-op; no cluster
311+
# contact happens.
312+
conn.close()
245313

246314

247315
class TestIsRowReturning:

tests/test_cycle3_tests_pinning.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,23 @@ class TestCursorAfterExternalConnectionClose:
2121
"""
2222

2323
def test_sync_cursor_execute_after_connection_close(self) -> None:
24-
import dqlitedbapi
24+
# Simulate a closed connection without touching any cluster.
25+
# The contract we're pinning: when the underlying connection's
26+
# _run_sync raises InterfaceError (its documented response to a
27+
# closed connection), the cursor surface propagates it.
28+
from dqlitedbapi.cursor import Cursor
2529

26-
conn = dqlitedbapi.connect("localhost:19001")
27-
cursor = conn.cursor()
28-
conn.close()
30+
class _ClosedConn:
31+
_closed = True
32+
33+
def _check_thread(self) -> None:
34+
return None
35+
36+
def _run_sync(self, coro) -> None: # noqa: ANN001
37+
coro.close()
38+
raise InterfaceError("Connection is closed")
39+
40+
cursor = Cursor(_ClosedConn()) # type: ignore[arg-type]
2941

3042
with pytest.raises(InterfaceError):
3143
cursor.execute("SELECT 1")

0 commit comments

Comments
 (0)