Skip to content

Commit dcba9a4

Browse files
Better async support
1 parent b36e98e commit dcba9a4

1 file changed

Lines changed: 149 additions & 7 deletions

File tree

src/sqlalchemydqlite/aio.py

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,135 @@
11
"""Async dqlite dialect for SQLAlchemy."""
22

3+
from collections import deque
4+
from collections.abc import Sequence
35
from typing import Any
46

57
from sqlalchemy import pool
6-
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
7-
from sqlalchemy.engine import URL
8+
from sqlalchemy.engine import URL, AdaptedConnection
9+
from sqlalchemy.engine.interfaces import DBAPIConnection
810
from sqlalchemy.pool import AsyncAdaptedQueuePool
11+
from sqlalchemy.util import await_only
912

13+
from sqlalchemydqlite.base import DqliteDialect
1014

11-
class DqliteDialect_aio(SQLiteDialect): # noqa: N801
15+
16+
class AsyncAdaptedCursor:
17+
"""Adapts an AsyncCursor for SQLAlchemy's greenlet-based async engine.
18+
19+
Eagerly fetches all rows during execute() within the greenlet context,
20+
then serves fetch* calls synchronously from the buffer. This matches
21+
the pattern used by SQLAlchemy's aiosqlite dialect.
22+
"""
23+
24+
server_side = False
25+
26+
def __init__(self, adapt_connection: "AsyncAdaptedConnection") -> None:
27+
self._adapt_connection = adapt_connection
28+
self._connection = adapt_connection._connection
29+
self.description: Any = None
30+
self.rowcount: int = -1
31+
self.lastrowid: int | None = None
32+
self.arraysize: int = 1
33+
self._rows: deque[Any] = deque()
34+
35+
async def _async_soft_close(self) -> None:
36+
return
37+
38+
def close(self) -> None:
39+
self._rows.clear()
40+
41+
def execute(self, operation: str, parameters: Any = None) -> Any:
42+
cursor = self._connection.cursor()
43+
if parameters is not None:
44+
await_only(cursor.execute(operation, parameters))
45+
else:
46+
await_only(cursor.execute(operation))
47+
48+
if cursor.description:
49+
self.description = cursor.description
50+
self.lastrowid = self.rowcount = -1
51+
self._rows = deque(await_only(cursor.fetchall()))
52+
else:
53+
self.description = None
54+
self.lastrowid = cursor.lastrowid
55+
self.rowcount = cursor.rowcount
56+
57+
await_only(cursor.close())
58+
59+
def executemany(self, operation: str, seq_of_parameters: Any) -> Any:
60+
cursor = self._connection.cursor()
61+
await_only(cursor.executemany(operation, seq_of_parameters))
62+
self.description = None
63+
self.lastrowid = cursor.lastrowid
64+
self.rowcount = cursor.rowcount
65+
await_only(cursor.close())
66+
67+
def fetchone(self) -> Any:
68+
if self._rows:
69+
return self._rows.popleft()
70+
return None
71+
72+
def fetchmany(self, size: int | None = None) -> Sequence[Any]:
73+
if size is None:
74+
size = self.arraysize
75+
return [self._rows.popleft() for _ in range(min(size, len(self._rows)))]
76+
77+
def fetchall(self) -> Sequence[Any]:
78+
retval = list(self._rows)
79+
self._rows.clear()
80+
return retval
81+
82+
def setinputsizes(self, *inputsizes: Any) -> None:
83+
pass
84+
85+
def setoutputsize(self, size: int, column: int | None = None) -> None:
86+
pass
87+
88+
def __iter__(self) -> Any:
89+
while self._rows:
90+
yield self._rows.popleft()
91+
92+
def __next__(self) -> Any:
93+
row = self.fetchone()
94+
if row is None:
95+
raise StopIteration
96+
return row
97+
98+
99+
class AsyncAdaptedConnection(AdaptedConnection):
100+
"""Adapts an AsyncConnection for SQLAlchemy's greenlet-based async engine.
101+
102+
Provides sync-looking methods that internally use await_only() to
103+
bridge to the underlying async connection within SQLAlchemy's
104+
greenlet context.
105+
"""
106+
107+
def __init__(self, connection: Any) -> None:
108+
self._connection = connection
109+
110+
def cursor(self) -> AsyncAdaptedCursor:
111+
return AsyncAdaptedCursor(self)
112+
113+
def commit(self) -> None:
114+
await_only(self._connection.commit())
115+
116+
def rollback(self) -> None:
117+
await_only(self._connection.rollback())
118+
119+
def close(self) -> None:
120+
await_only(self._connection.close())
121+
122+
123+
class DqliteDialect_aio(DqliteDialect): # noqa: N801
12124
"""Async SQLAlchemy dialect for dqlite.
13125
14126
Use with SQLAlchemy's async engine:
15127
create_async_engine("dqlite+aio://host:port/database")
16128
"""
17129

18-
name = "dqlite"
19130
driver = "dqlitedbapi_aio"
20131
is_async = True
21-
22-
# dqlite uses qmark parameter style
23-
paramstyle = "qmark"
132+
supports_statement_cache = True
24133

25134
@classmethod
26135
def get_pool_class(cls, url: URL) -> type[pool.Pool]:
@@ -32,6 +141,11 @@ def import_dbapi(cls) -> Any:
32141

33142
return aio
34143

144+
def connect(self, *cargs: Any, **cparams: Any) -> Any:
145+
"""Create and wrap an async connection."""
146+
raw_conn = self.loaded_dbapi.connect(*cargs, **cparams)
147+
return AsyncAdaptedConnection(raw_conn)
148+
35149
def create_connect_args(self, url: URL) -> tuple[list[Any], dict[str, Any]]:
36150
"""Create connection arguments from URL.
37151
@@ -48,6 +162,34 @@ def create_connect_args(self, url: URL) -> tuple[list[Any], dict[str, Any]]:
48162
"database": database,
49163
}
50164

165+
def do_rollback(self, dbapi_connection: DBAPIConnection) -> None:
166+
"""Rollback the current transaction."""
167+
try:
168+
dbapi_connection.rollback()
169+
except Exception as e:
170+
if "no transaction is active" not in str(e):
171+
raise
172+
173+
def do_commit(self, dbapi_connection: DBAPIConnection) -> None:
174+
"""Commit the current transaction."""
175+
try:
176+
dbapi_connection.commit()
177+
except Exception as e:
178+
if "no transaction is active" not in str(e):
179+
raise
180+
181+
def _get_server_version_info(self, connection: Any) -> tuple[int, ...]:
182+
"""Return the server version as a tuple."""
183+
cursor = connection.connection.dbapi_connection.cursor()
184+
cursor.execute("SELECT sqlite_version()")
185+
row = cursor.fetchone()
186+
cursor.close()
187+
188+
if row:
189+
version_str = row[0]
190+
return tuple(int(x) for x in version_str.split("."))
191+
return (3, 0, 0)
192+
51193
def get_driver_connection(self, connection: Any) -> Any:
52194
"""Return the driver-level connection."""
53195
return connection

0 commit comments

Comments
 (0)