Skip to content

Commit 4bc128c

Browse files
Route cursors through DqliteConnection public API
Both sync Cursor and async AsyncCursor were accessing conn._protocol directly, bypassing _run_protocol() which provides the _in_use guard, connection invalidation on fatal errors, and leader-change detection. Refactor both cursors to use conn.execute() and conn.query_raw() instead, which properly go through _run_protocol(). This prevents protocol stream corruption from concurrent operations and ensures broken connections are properly invalidated. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bfc0bb8 commit 4bc128c

File tree

7 files changed

+162
-112
lines changed

7 files changed

+162
-112
lines changed

src/dqlitedbapi/aio/cursor.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Sequence
44
from typing import TYPE_CHECKING, Any
55

6-
from dqlitedbapi.exceptions import InterfaceError, InternalError, OperationalError
6+
from dqlitedbapi.exceptions import InterfaceError
77

88
if TYPE_CHECKING:
99
from dqlitedbapi.aio.connection import AsyncConnection
@@ -74,45 +74,39 @@ def _check_closed(self) -> None:
7474
async def execute(
7575
self, operation: str, parameters: Sequence[Any] | None = None
7676
) -> "AsyncCursor":
77-
"""Execute a database operation (query or command)."""
77+
"""Execute a database operation (query or command).
78+
79+
Routes through DqliteConnection's public API (execute/query_raw)
80+
which goes through _run_protocol(), providing the _in_use guard,
81+
connection invalidation on fatal errors, and leader-change detection.
82+
The _op_lock serializes operations on the same connection.
83+
"""
7884
self._check_closed()
7985

8086
conn = await self._connection._ensure_connection()
8187
params = list(parameters) if parameters is not None else None
8288

8389
# Determine if this is a query that returns rows.
8490
# Note: WITH ... INSERT/UPDATE/DELETE (without RETURNING) will be
85-
# misrouted to query_sql. This is a known limitation of the heuristic.
91+
# misrouted to query_raw. This is a known limitation of the heuristic.
8692
normalized = _strip_leading_comments(operation).upper()
8793
is_query = normalized.startswith(("SELECT", "PRAGMA", "EXPLAIN", "WITH")) or (
8894
" RETURNING " in normalized or normalized.endswith(" RETURNING")
8995
)
9096

91-
if conn._protocol is None or conn._db_id is None:
92-
raise InternalError("Connection protocol not initialized")
93-
9497
async with self._connection._op_lock:
95-
try:
96-
if is_query:
97-
columns, rows = await conn._protocol.query_sql(conn._db_id, operation, params)
98-
self._description = [
99-
(name, None, None, None, None, None, None) for name in columns
100-
]
101-
self._rows = [tuple(row) for row in rows]
102-
self._row_index = 0
103-
self._rowcount = len(rows)
104-
else:
105-
last_id, affected = await conn._protocol.exec_sql(
106-
conn._db_id, operation, params
107-
)
108-
self._lastrowid = last_id
109-
self._rowcount = affected
110-
self._description = None
111-
self._rows = []
112-
except (OperationalError, InterfaceError, InternalError):
113-
raise
114-
except Exception as e:
115-
raise OperationalError(str(e)) from e
98+
if is_query:
99+
columns, rows = await conn.query_raw(operation, params)
100+
self._description = [(name, None, None, None, None, None, None) for name in columns]
101+
self._rows = [tuple(row) for row in rows]
102+
self._row_index = 0
103+
self._rowcount = len(rows)
104+
else:
105+
last_id, affected = await conn.execute(operation, params)
106+
self._lastrowid = last_id
107+
self._rowcount = affected
108+
self._description = None
109+
self._rows = []
116110

117111
return self
118112

src/dqlitedbapi/cursor.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Sequence
44
from typing import TYPE_CHECKING, Any
55

6-
from dqlitedbapi.exceptions import InterfaceError, InternalError, OperationalError
6+
from dqlitedbapi.exceptions import InterfaceError
77

88
if TYPE_CHECKING:
99
from dqlitedbapi.connection import Connection
@@ -89,38 +89,35 @@ def execute(self, operation: str, parameters: Sequence[Any] | None = None) -> "C
8989
return self
9090

9191
async def _execute_async(self, operation: str, parameters: Sequence[Any] | None = None) -> None:
92-
"""Async implementation of execute."""
92+
"""Async implementation of execute.
93+
94+
Routes through DqliteConnection's public API (execute/query_raw)
95+
which goes through _run_protocol(), providing the _in_use guard,
96+
connection invalidation on fatal errors, and leader-change detection.
97+
"""
9398
conn = await self._connection._get_async_connection()
9499
params = list(parameters) if parameters is not None else None
95100

96101
# Determine if this is a query that returns rows.
97102
# Note: WITH ... INSERT/UPDATE/DELETE (without RETURNING) will be
98-
# misrouted to query_sql. This is a known limitation of the heuristic.
103+
# misrouted to query_raw. This is a known limitation of the heuristic.
99104
normalized = _strip_leading_comments(operation).upper()
100105
is_query = normalized.startswith(("SELECT", "PRAGMA", "EXPLAIN", "WITH")) or (
101106
" RETURNING " in normalized or normalized.endswith(" RETURNING")
102107
)
103108

104-
if conn._protocol is None or conn._db_id is None:
105-
raise InternalError("Connection protocol not initialized")
106-
107-
try:
108-
if is_query:
109-
columns, rows = await conn._protocol.query_sql(conn._db_id, operation, params)
110-
self._description = [(name, None, None, None, None, None, None) for name in columns]
111-
self._rows = [tuple(row) for row in rows]
112-
self._row_index = 0
113-
self._rowcount = len(rows)
114-
else:
115-
last_id, affected = await conn._protocol.exec_sql(conn._db_id, operation, params)
116-
self._lastrowid = last_id
117-
self._rowcount = affected
118-
self._description = None
119-
self._rows = []
120-
except (OperationalError, InterfaceError, InternalError):
121-
raise
122-
except Exception as e:
123-
raise OperationalError(str(e)) from e
109+
if is_query:
110+
columns, rows = await conn.query_raw(operation, params)
111+
self._description = [(name, None, None, None, None, None, None) for name in columns]
112+
self._rows = [tuple(row) for row in rows]
113+
self._row_index = 0
114+
self._rowcount = len(rows)
115+
else:
116+
last_id, affected = await conn.execute(operation, params)
117+
self._lastrowid = last_id
118+
self._rowcount = affected
119+
self._description = None
120+
self._rows = []
124121

125122
def executemany(self, operation: str, seq_of_parameters: Sequence[Sequence[Any]]) -> "Cursor":
126123
"""Execute a database operation multiple times."""
Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
1-
"""Tests that cursor raises InternalError when protocol is not initialized."""
1+
"""Tests that cursor raises errors when the connection is not usable."""
22

33
import asyncio
44
from unittest.mock import AsyncMock, MagicMock
55

66
import pytest
77

88
from dqlitedbapi.cursor import Cursor
9-
from dqlitedbapi.exceptions import InternalError
109

1110

12-
def _make_mock_connection_no_protocol() -> MagicMock:
13-
"""Create a mock Connection where _protocol is None."""
11+
def _make_mock_connection_not_connected() -> MagicMock:
12+
"""Create a mock Connection where the underlying DqliteConnection is not connected."""
13+
from dqliteclient.exceptions import DqliteConnectionError
14+
1415
mock_async_conn = AsyncMock()
15-
mock_async_conn._protocol = None
16-
mock_async_conn._db_id = None
16+
mock_async_conn.query_raw = AsyncMock(side_effect=DqliteConnectionError("Not connected"))
17+
mock_async_conn.execute = AsyncMock(side_effect=DqliteConnectionError("Not connected"))
1718

1819
mock_conn = MagicMock()
1920

@@ -35,18 +36,18 @@ def run_sync(coro: object) -> object:
3536

3637

3738
class TestCursorProtocolCheck:
38-
def test_execute_query_raises_internal_error_when_protocol_none(self) -> None:
39-
"""execute() should raise InternalError, not AssertionError, when protocol is None."""
40-
mock_conn = _make_mock_connection_no_protocol()
39+
def test_execute_query_raises_error_when_not_connected(self) -> None:
40+
"""execute() should raise when the connection is not connected."""
41+
mock_conn = _make_mock_connection_not_connected()
4142
cursor = Cursor(mock_conn)
4243

43-
with pytest.raises(InternalError, match="Connection protocol not initialized"):
44+
with pytest.raises(Exception, match="Not connected"):
4445
cursor.execute("SELECT 1")
4546

46-
def test_execute_dml_raises_internal_error_when_protocol_none(self) -> None:
47-
"""execute() should raise InternalError for DML when protocol is None."""
48-
mock_conn = _make_mock_connection_no_protocol()
47+
def test_execute_dml_raises_error_when_not_connected(self) -> None:
48+
"""execute() should raise for DML when the connection is not connected."""
49+
mock_conn = _make_mock_connection_not_connected()
4950
cursor = Cursor(mock_conn)
5051

51-
with pytest.raises(InternalError, match="Connection protocol not initialized"):
52+
with pytest.raises(Exception, match="Not connected"):
5253
cursor.execute("INSERT INTO t VALUES (1)")
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Tests that cursors route through DqliteConnection's public API.
2+
3+
The sync Cursor and async AsyncCursor must not access conn._protocol directly.
4+
They should use conn.execute() / conn.query_raw() (or similar) which go through
5+
_run_protocol(), providing the _in_use guard, connection invalidation on fatal
6+
errors, and leader-change detection.
7+
"""
8+
9+
import ast
10+
import inspect
11+
import textwrap
12+
from typing import Any
13+
14+
from dqlitedbapi.aio.cursor import AsyncCursor
15+
from dqlitedbapi.cursor import Cursor
16+
17+
18+
def _has_direct_attr_access(func: Any, attr_name: str) -> bool:
19+
"""Check if a function accesses conn.<attr_name> directly via AST."""
20+
source = textwrap.dedent(inspect.getsource(func))
21+
tree = ast.parse(source)
22+
23+
for node in ast.walk(tree):
24+
if (
25+
isinstance(node, ast.Attribute)
26+
and node.attr == attr_name
27+
and isinstance(node.value, ast.Name)
28+
and node.value.id == "conn"
29+
):
30+
return True
31+
return False
32+
33+
34+
class TestSyncCursorDoesNotAccessProtocolDirectly:
35+
def test_execute_async_does_not_access_conn_protocol(self) -> None:
36+
"""_execute_async must not access conn._protocol directly."""
37+
assert not _has_direct_attr_access(Cursor._execute_async, "_protocol"), (
38+
"Cursor._execute_async accesses conn._protocol directly. "
39+
"It should use conn.execute() / conn.query_raw()."
40+
)
41+
42+
def test_execute_async_does_not_access_conn_db_id(self) -> None:
43+
"""_execute_async must not access conn._db_id directly."""
44+
assert not _has_direct_attr_access(Cursor._execute_async, "_db_id"), (
45+
"Cursor._execute_async accesses conn._db_id directly. "
46+
"It should use the public API on DqliteConnection."
47+
)
48+
49+
50+
class TestAsyncCursorDoesNotAccessProtocolDirectly:
51+
def test_execute_does_not_access_conn_protocol(self) -> None:
52+
"""AsyncCursor.execute must not access conn._protocol directly."""
53+
assert not _has_direct_attr_access(AsyncCursor.execute, "_protocol"), (
54+
"AsyncCursor.execute accesses conn._protocol directly. "
55+
"It should use conn.execute() / conn.query_raw()."
56+
)
57+
58+
def test_execute_does_not_access_conn_db_id(self) -> None:
59+
"""AsyncCursor.execute must not access conn._db_id directly."""
60+
assert not _has_direct_attr_access(AsyncCursor.execute, "_db_id"), (
61+
"AsyncCursor.execute accesses conn._db_id directly. "
62+
"It should use the public API on DqliteConnection."
63+
)

tests/test_exception_wrapping.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
"""Tests that protocol exceptions are wrapped in PEP 249 exception types."""
1+
"""Tests that protocol exceptions propagate through the cursor.
2+
3+
Now that the cursor delegates to DqliteConnection.query_raw()/execute(),
4+
exception wrapping is handled by DqliteConnection._run_protocol().
5+
These tests verify that exceptions from the connection layer propagate
6+
correctly through the cursor.
7+
"""
28

39
import asyncio
410
from unittest.mock import AsyncMock, MagicMock
@@ -10,14 +16,10 @@
1016

1117

1218
def _make_mock_connection_with_error(error: Exception) -> MagicMock:
13-
"""Create a mock Connection where protocol raises the given error."""
14-
mock_protocol = AsyncMock()
15-
mock_protocol.exec_sql = AsyncMock(side_effect=error)
16-
mock_protocol.query_sql = AsyncMock(side_effect=error)
17-
19+
"""Create a mock Connection where query_raw/execute raise the given error."""
1820
mock_async_conn = AsyncMock()
19-
mock_async_conn._protocol = mock_protocol
20-
mock_async_conn._db_id = 0
21+
mock_async_conn.execute = AsyncMock(side_effect=error)
22+
mock_async_conn.query_raw = AsyncMock(side_effect=error)
2123

2224
mock_conn = MagicMock()
2325

@@ -39,26 +41,26 @@ def run_sync(coro: object) -> object:
3941

4042

4143
class TestExceptionWrapping:
42-
def test_connection_error_wrapped_as_operational_error(self) -> None:
43-
"""ConnectionError from protocol should become OperationalError."""
44-
mock_conn = _make_mock_connection_with_error(ConnectionError("connection lost"))
44+
def test_operational_error_propagates(self) -> None:
45+
"""OperationalError from DqliteConnection should propagate through cursor."""
46+
mock_conn = _make_mock_connection_with_error(OperationalError("connection lost"))
4547
cursor = Cursor(mock_conn)
4648

4749
with pytest.raises(OperationalError, match="connection lost"):
4850
cursor.execute("SELECT 1")
4951

50-
def test_os_error_wrapped_as_operational_error(self) -> None:
51-
"""OSError from protocol should become OperationalError."""
52-
mock_conn = _make_mock_connection_with_error(OSError("network unreachable"))
52+
def test_dml_error_propagates(self) -> None:
53+
"""Errors from DqliteConnection.execute() should propagate through cursor."""
54+
mock_conn = _make_mock_connection_with_error(OperationalError("network unreachable"))
5355
cursor = Cursor(mock_conn)
5456

5557
with pytest.raises(OperationalError, match="network unreachable"):
5658
cursor.execute("INSERT INTO t VALUES (1)")
5759

58-
def test_runtime_error_wrapped_as_operational_error(self) -> None:
59-
"""Generic exceptions from protocol should become OperationalError."""
60+
def test_generic_exception_propagates(self) -> None:
61+
"""Generic exceptions from DqliteConnection should propagate through cursor."""
6062
mock_conn = _make_mock_connection_with_error(RuntimeError("unexpected"))
6163
cursor = Cursor(mock_conn)
6264

63-
with pytest.raises(OperationalError, match="unexpected"):
65+
with pytest.raises(RuntimeError, match="unexpected"):
6466
cursor.execute("SELECT 1")

0 commit comments

Comments
 (0)