Skip to content

Commit 467a411

Browse files
Forward governors and validate timeout in the async PEP 249 factories
aio.connect() / aio.aconnect() accepted only address / database / timeout and silently dropped the three DoS-governor parameters that AsyncConnection already supports: max_total_rows, max_continuation_frames, and trust_server_heartbeat. This mirrored the sync-side gap already closed earlier, but the async side was missed. Worse, the SQLAlchemy async dialect forwards URL query parameters straight into aio.connect() via loaded_dbapi.connect(**kwargs), so a URL like dqlite+aio://host/db?max_total_rows=500 raised TypeError: unexpected keyword argument 'max_total_rows' at engine-connect time rather than applying the cap. Add the three kwargs to both factories with the same defaults as the sync side, plus a math.isfinite timeout check at the public boundary so errors surface at call-site in both sync and async APIs. Regression tests cover: all governors forwarded, defaults, None caps, and invalid timeout values. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 582df33 commit 467a411

File tree

2 files changed

+91
-4
lines changed

2 files changed

+91
-4
lines changed

src/dqlitedbapi/aio/__init__.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Async PEP 249-style interface for dqlite."""
22

3+
import math
4+
35
from dqlitedbapi import __version__
46
from dqlitedbapi.aio.connection import AsyncConnection
57
from dqlitedbapi.aio.cursor import AsyncCursor
@@ -91,6 +93,9 @@ def connect(
9193
*,
9294
database: str = "default",
9395
timeout: float = 10.0,
96+
max_total_rows: int | None = 10_000_000,
97+
max_continuation_frames: int | None = 100_000,
98+
trust_server_heartbeat: bool = False,
9499
) -> AsyncConnection:
95100
"""Create a dqlite connection (connects lazily on first use).
96101
@@ -101,19 +106,40 @@ def connect(
101106
Args:
102107
address: Node address in "host:port" format
103108
database: Database name to open
104-
timeout: Connection timeout in seconds
109+
timeout: Connection timeout in seconds — must be a positive
110+
finite number. 0, negatives, and non-finite values are
111+
rejected here rather than silently passed through.
112+
max_total_rows: Cumulative row cap across continuation frames
113+
for a single query. Forwarded to the underlying
114+
AsyncConnection. None disables the cap.
115+
max_continuation_frames: Per-query continuation-frame cap.
116+
Forwarded to the underlying AsyncConnection.
117+
trust_server_heartbeat: Let the server-advertised heartbeat
118+
widen the per-read deadline. Default False.
105119
106120
Returns:
107121
An AsyncConnection object
108122
"""
109-
return AsyncConnection(address, database=database, timeout=timeout)
123+
if not math.isfinite(timeout) or timeout <= 0:
124+
raise ProgrammingError(f"timeout must be a positive finite number, got {timeout}")
125+
return AsyncConnection(
126+
address,
127+
database=database,
128+
timeout=timeout,
129+
max_total_rows=max_total_rows,
130+
max_continuation_frames=max_continuation_frames,
131+
trust_server_heartbeat=trust_server_heartbeat,
132+
)
110133

111134

112135
async def aconnect(
113136
address: str,
114137
*,
115138
database: str = "default",
116139
timeout: float = 10.0,
140+
max_total_rows: int | None = 10_000_000,
141+
max_continuation_frames: int | None = 100_000,
142+
trust_server_heartbeat: bool = False,
117143
) -> AsyncConnection:
118144
"""Connect to a dqlite database asynchronously.
119145
@@ -122,11 +148,24 @@ async def aconnect(
122148
Args:
123149
address: Node address in "host:port" format
124150
database: Database name to open
125-
timeout: Connection timeout in seconds
151+
timeout: Connection timeout in seconds — must be positive and finite.
152+
max_total_rows: Cumulative row cap across continuation frames.
153+
max_continuation_frames: Per-query continuation-frame cap.
154+
trust_server_heartbeat: Let the server-advertised heartbeat
155+
widen the per-read deadline.
126156
127157
Returns:
128158
A connected AsyncConnection object
129159
"""
130-
conn = AsyncConnection(address, database=database, timeout=timeout)
160+
if not math.isfinite(timeout) or timeout <= 0:
161+
raise ProgrammingError(f"timeout must be a positive finite number, got {timeout}")
162+
conn = AsyncConnection(
163+
address,
164+
database=database,
165+
timeout=timeout,
166+
max_total_rows=max_total_rows,
167+
max_continuation_frames=max_continuation_frames,
168+
trust_server_heartbeat=trust_server_heartbeat,
169+
)
131170
await conn.connect()
132171
return conn

tests/test_aio_module_attributes.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Tests for async module PEP 249 attributes and exports."""
22

3+
import pytest
4+
35
from dqlitedbapi import aio
6+
from dqlitedbapi.exceptions import ProgrammingError
47

58

69
class TestAioModuleAttributes:
@@ -30,3 +33,48 @@ def test_type_objects_exported(self) -> None:
3033
assert aio.NUMBER == "INTEGER"
3134
assert aio.DATETIME == "DATE"
3235
assert aio.ROWID == "ROWID"
36+
37+
38+
class TestAioConnectForwardsGovernors:
39+
def test_aio_connect_forwards_all_governors(self) -> None:
40+
conn = aio.connect(
41+
"localhost:19001",
42+
max_total_rows=500,
43+
max_continuation_frames=7,
44+
trust_server_heartbeat=True,
45+
)
46+
assert conn._max_total_rows == 500
47+
assert conn._max_continuation_frames == 7
48+
assert conn._trust_server_heartbeat is True
49+
50+
def test_aio_connect_uses_default_governors(self) -> None:
51+
conn = aio.connect("localhost:19001")
52+
assert conn._max_total_rows == 10_000_000
53+
assert conn._max_continuation_frames == 100_000
54+
assert conn._trust_server_heartbeat is False
55+
56+
def test_aio_connect_accepts_none_caps(self) -> None:
57+
conn = aio.connect(
58+
"localhost:19001",
59+
max_total_rows=None,
60+
max_continuation_frames=None,
61+
)
62+
assert conn._max_total_rows is None
63+
assert conn._max_continuation_frames is None
64+
65+
66+
class TestAioConnectTimeoutValidation:
67+
@pytest.mark.parametrize("bad", [0, -1, float("nan"), float("inf"), float("-inf")])
68+
def test_connect_rejects_non_positive_or_non_finite(self, bad: float) -> None:
69+
with pytest.raises(ProgrammingError, match="timeout must be a positive finite number"):
70+
aio.connect("localhost:19001", timeout=bad)
71+
72+
@pytest.mark.parametrize("bad", [0, -1, float("nan"), float("inf"), float("-inf")])
73+
def test_aconnect_rejects_non_positive_or_non_finite(self, bad: float) -> None:
74+
import asyncio
75+
76+
async def run() -> None:
77+
with pytest.raises(ProgrammingError, match="timeout must be a positive finite number"):
78+
await aio.aconnect("localhost:19001", timeout=bad)
79+
80+
asyncio.run(run())

0 commit comments

Comments
 (0)