Skip to content

Commit f7876cb

Browse files
Wrap client exceptions into PEP 249 hierarchy; serialize close
ISSUE-31 — cursor.execute used to let dqliteclient.exceptions.* leak through to the caller. PEP 249 requires all error classes to live under the module's own hierarchy. New _call_client helper wraps the client call, mapping: client.OperationalError → dbapi.OperationalError client.DqliteConnectionError → dbapi.OperationalError client.ClusterError → dbapi.OperationalError client.ProtocolError → dbapi.InterfaceError client.DataError → dbapi.DataError client.InterfaceError → dbapi.InterfaceError Original cause preserved via `from`. Applied at both sync and async cursor entry points. ISSUE-32 — AsyncConnection.close() now acquires _op_lock before tearing down the underlying protocol. A concurrent task mid-execute used to find the socket closed underneath it; close now waits for the in-flight operation to complete. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a81a639 commit f7876cb

File tree

5 files changed

+205
-11
lines changed

5 files changed

+205
-11
lines changed

src/dqlitedbapi/aio/connection.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,26 @@ async def connect(self) -> None:
7979
await self._ensure_connection()
8080

8181
async def close(self) -> None:
82-
"""Close the connection."""
82+
"""Close the connection.
83+
84+
Serializes with any in-flight operation via ``_op_lock`` so we
85+
never tear down the underlying protocol while another task is
86+
mid-execute/mid-commit — that races would leave the caller
87+
with mysterious "connection closed" errors mid-query.
88+
"""
8389
if self._closed:
8490
return
91+
# Set _closed first so any task waiting on the lock sees the
92+
# closed state as soon as it acquires. Then drain the current
93+
# in-flight op (if any) under the lock.
8594
self._closed = True
86-
if self._async_conn is not None:
87-
await self._async_conn.close()
88-
self._async_conn = None
95+
if self._async_conn is None:
96+
return
97+
_, op_lock = self._ensure_locks()
98+
async with op_lock:
99+
if self._async_conn is not None:
100+
await self._async_conn.close()
101+
self._async_conn = None
89102

90103
async def commit(self) -> None:
91104
"""Commit any pending transaction.

src/dqlitedbapi/aio/cursor.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
from collections.abc import Sequence
44
from typing import TYPE_CHECKING, Any
55

6-
from dqlitedbapi.cursor import _convert_params, _convert_row, _strip_leading_comments
6+
from dqlitedbapi.cursor import (
7+
_call_client,
8+
_convert_params,
9+
_convert_row,
10+
_strip_leading_comments,
11+
)
712
from dqlitedbapi.exceptions import InterfaceError
813

914
if TYPE_CHECKING:
@@ -82,7 +87,9 @@ async def execute(
8287
_, op_lock = self._connection._ensure_locks()
8388
async with op_lock:
8489
if is_query:
85-
columns, column_types, rows = await conn.query_raw_typed(operation, params)
90+
columns, column_types, rows = await _call_client(
91+
conn.query_raw_typed(operation, params)
92+
)
8693
self._description = [
8794
(
8895
name,
@@ -99,7 +106,7 @@ async def execute(
99106
self._row_index = 0
100107
self._rowcount = len(rows)
101108
else:
102-
last_id, affected = await conn.execute(operation, params)
109+
last_id, affected = await _call_client(conn.execute(operation, params))
103110
self._lastrowid = last_id
104111
self._rowcount = affected
105112
self._description = None

src/dqlitedbapi/cursor.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,54 @@
11
"""PEP 249 Cursor implementation for dqlite."""
22

3-
from collections.abc import Callable, Mapping, Sequence
3+
from collections.abc import Callable, Coroutine, Mapping, Sequence
44
from typing import TYPE_CHECKING, Any
55

6+
import dqliteclient.exceptions as _client_exc
67
from dqlitewire.constants import ValueType
78

8-
from dqlitedbapi.exceptions import InterfaceError, ProgrammingError
9+
from dqlitedbapi.exceptions import (
10+
DataError,
11+
InterfaceError,
12+
OperationalError,
13+
ProgrammingError,
14+
)
15+
from dqlitedbapi.exceptions import (
16+
InterfaceError as _DbapiInterfaceError,
17+
)
918
from dqlitedbapi.types import (
1019
_convert_bind_param,
1120
_datetime_from_iso8601,
1221
_datetime_from_unixtime,
1322
)
1423

24+
25+
async def _call_client(coro: Coroutine[Any, Any, Any]) -> Any:
26+
"""Await a client-layer coroutine, mapping its exceptions into the
27+
PEP 249 hierarchy. Preserves the original via ``from``.
28+
29+
Mapping:
30+
client.OperationalError → dbapi.OperationalError (same code/msg)
31+
client.DqliteConnectionError → dbapi.OperationalError (network flavor)
32+
client.ClusterError → dbapi.OperationalError
33+
client.ProtocolError → dbapi.InterfaceError
34+
client.DataError → dbapi.DataError
35+
client.InterfaceError → dbapi.InterfaceError
36+
"""
37+
try:
38+
return await coro
39+
except _client_exc.OperationalError as e:
40+
raise OperationalError(str(e)) from e
41+
except _client_exc.DqliteConnectionError as e:
42+
raise OperationalError(str(e)) from e
43+
except _client_exc.ClusterError as e:
44+
raise OperationalError(str(e)) from e
45+
except _client_exc.ProtocolError as e:
46+
raise _DbapiInterfaceError(str(e)) from e
47+
except _client_exc.DataError as e:
48+
raise DataError(str(e)) from e
49+
except _client_exc.InterfaceError as e:
50+
raise _DbapiInterfaceError(str(e)) from e
51+
1552
if TYPE_CHECKING:
1653
from dqlitedbapi.connection import Connection
1754

@@ -165,7 +202,9 @@ async def _execute_async(self, operation: str, parameters: Sequence[Any] | None
165202
)
166203

167204
if is_query:
168-
columns, column_types, rows = await conn.query_raw_typed(operation, params)
205+
columns, column_types, rows = await _call_client(
206+
conn.query_raw_typed(operation, params)
207+
)
169208
self._description = [
170209
(
171210
name,
@@ -182,7 +221,7 @@ async def _execute_async(self, operation: str, parameters: Sequence[Any] | None
182221
self._row_index = 0
183222
self._rowcount = len(rows)
184223
else:
185-
last_id, affected = await conn.execute(operation, params)
224+
last_id, affected = await _call_client(conn.execute(operation, params))
186225
self._lastrowid = last_id
187226
self._rowcount = affected
188227
self._description = None

tests/test_async_close_race.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""AsyncConnection.close serializes with in-flight operations (ISSUE-32).
2+
3+
Previously close() called await self._async_conn.close() without
4+
acquiring _op_lock; a concurrent task mid-execute would find the
5+
protocol torn down underneath it. Now close() acquires the lock so
6+
in-flight operations drain cleanly before the socket is closed.
7+
"""
8+
9+
import asyncio
10+
from unittest.mock import AsyncMock, patch
11+
12+
import pytest
13+
14+
from dqlitedbapi.aio.connection import AsyncConnection
15+
16+
17+
@pytest.mark.asyncio
18+
async def test_close_waits_for_in_flight_execute() -> None:
19+
"""If task A is mid-execute (holding _op_lock), task B's close
20+
awaits that task before tearing down the protocol."""
21+
conn = AsyncConnection("localhost:19001", database="x")
22+
23+
order: list[str] = []
24+
25+
async def slow_execute(_sql: str, _params: object) -> tuple[int, int]:
26+
order.append("execute:start")
27+
await asyncio.sleep(0.05)
28+
order.append("execute:end")
29+
return (0, 0)
30+
31+
async def fake_query_raw_typed(_sql: str, _params: object) -> tuple[list, list, list]:
32+
return ([], [], [])
33+
34+
with patch("dqlitedbapi.aio.connection.DqliteConnection") as MockDqliteConn:
35+
mock_instance = AsyncMock()
36+
mock_instance.connect = AsyncMock()
37+
mock_instance.execute = slow_execute
38+
mock_instance.query_raw_typed = fake_query_raw_typed
39+
40+
async def fake_close() -> None:
41+
order.append("close:start")
42+
order.append("close:end")
43+
44+
mock_instance.close = fake_close
45+
MockDqliteConn.return_value = mock_instance
46+
47+
await conn.connect()
48+
49+
async def run_execute() -> None:
50+
cursor = conn.cursor()
51+
await cursor.execute("INSERT INTO t VALUES (1)")
52+
53+
async def run_close() -> None:
54+
await asyncio.sleep(0.01) # let execute start first
55+
await conn.close()
56+
57+
await asyncio.gather(run_execute(), run_close())
58+
59+
# execute must complete before close starts tearing down.
60+
assert order == [
61+
"execute:start",
62+
"execute:end",
63+
"close:start",
64+
"close:end",
65+
], f"actual order: {order}"

tests/test_exception_mapping.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""PEP 249 exception wrapping at the cursor layer (ISSUE-31).
2+
3+
The underlying client raises ``dqliteclient.exceptions.*``. PEP 249
4+
requires a specific exception hierarchy under the DBAPI module, so the
5+
cursor must translate those to ``dqlitedbapi.exceptions.*`` when they
6+
surface to the user.
7+
"""
8+
9+
import asyncio
10+
from unittest.mock import AsyncMock, MagicMock
11+
12+
import dqliteclient.exceptions as client_exc
13+
import pytest
14+
15+
import dqlitedbapi.exceptions as dbapi_exc
16+
from dqlitedbapi.cursor import Cursor
17+
18+
19+
def _cursor_with_async_conn_raising(exc: Exception) -> Cursor:
20+
"""Build a Cursor whose underlying async conn raises ``exc`` on every
21+
query/execute. Bypasses the event-loop thread by running the coroutine
22+
in a fresh event loop."""
23+
mock_async_conn = AsyncMock()
24+
mock_async_conn.query_raw_typed = AsyncMock(side_effect=exc)
25+
mock_async_conn.execute = AsyncMock(side_effect=exc)
26+
27+
mock_conn = MagicMock()
28+
29+
async def get_async_conn() -> AsyncMock:
30+
return mock_async_conn
31+
32+
mock_conn._get_async_connection = get_async_conn
33+
34+
def run_sync(coro: object) -> object:
35+
return asyncio.new_event_loop().run_until_complete(coro) # type: ignore[arg-type]
36+
37+
mock_conn._run_sync = run_sync
38+
return Cursor(mock_conn)
39+
40+
41+
class TestExceptionWrapping:
42+
def test_client_operational_error_becomes_dbapi_operational_error(self) -> None:
43+
c = _cursor_with_async_conn_raising(client_exc.OperationalError(1, "boom"))
44+
with pytest.raises(dbapi_exc.OperationalError, match="boom"):
45+
c.execute("SELECT 1")
46+
47+
def test_client_connection_error_becomes_operational_error(self) -> None:
48+
c = _cursor_with_async_conn_raising(client_exc.DqliteConnectionError("no route"))
49+
with pytest.raises(dbapi_exc.OperationalError, match="no route"):
50+
c.execute("SELECT 1")
51+
52+
def test_client_protocol_error_becomes_interface_error(self) -> None:
53+
c = _cursor_with_async_conn_raising(client_exc.ProtocolError("bad frame"))
54+
with pytest.raises(dbapi_exc.InterfaceError, match="bad frame"):
55+
c.execute("SELECT 1")
56+
57+
def test_client_data_error_becomes_data_error(self) -> None:
58+
c = _cursor_with_async_conn_raising(client_exc.DataError("bad param"))
59+
with pytest.raises(dbapi_exc.DataError, match="bad param"):
60+
c.execute("INSERT INTO t VALUES (?)", [object()])
61+
62+
def test_chained_cause_preserved(self) -> None:
63+
original = client_exc.OperationalError(1, "original")
64+
c = _cursor_with_async_conn_raising(original)
65+
try:
66+
c.execute("SELECT 1")
67+
except dbapi_exc.OperationalError as e:
68+
assert e.__cause__ is original
69+
else:
70+
pytest.fail("expected OperationalError")

0 commit comments

Comments
 (0)