22
33import asyncio
44import contextlib
5- from collections .abc import AsyncIterator , Sequence
5+ from collections .abc import AsyncIterator , Awaitable , Callable , Sequence
66from contextlib import asynccontextmanager
77from typing import Any
88
@@ -206,16 +206,17 @@ def _invalidate(self) -> None:
206206 self ._protocol = None
207207 self ._db_id = None
208208
209- async def execute (self , sql : str , params : Sequence [ Any ] | None = None ) -> tuple [ int , int ] :
210- """Execute a SQL statement .
209+ async def _run_protocol [ T ] (self , fn : Callable [[ DqliteProtocol , int ], Awaitable [ T ]] ) -> T :
210+ """Run a protocol operation with standard error handling .
211211
212- Returns (last_insert_id, rows_affected).
212+ Handles connection guards (_check_in_use, _ensure_connected, _in_use),
213+ invalidates the connection on fatal errors, and resets _in_use in all cases.
213214 """
214215 self ._check_in_use ()
215216 protocol , db_id = self ._ensure_connected ()
216217 self ._in_use = True
217218 try :
218- return await protocol . exec_sql ( db_id , sql , params )
219+ return await fn ( protocol , db_id )
219220 except (DqliteConnectionError , ProtocolError ):
220221 self ._invalidate ()
221222 raise
@@ -229,46 +230,21 @@ async def execute(self, sql: str, params: Sequence[Any] | None = None) -> tuple[
229230 finally :
230231 self ._in_use = False
231232
233+ async def execute (self , sql : str , params : Sequence [Any ] | None = None ) -> tuple [int , int ]:
234+ """Execute a SQL statement.
235+
236+ Returns (last_insert_id, rows_affected).
237+ """
238+ return await self ._run_protocol (lambda p , db : p .exec_sql (db , sql , params ))
239+
232240 async def fetch (self , sql : str , params : Sequence [Any ] | None = None ) -> list [dict [str , Any ]]:
233241 """Execute a query and return results as list of dicts."""
234- self ._check_in_use ()
235- protocol , db_id = self ._ensure_connected ()
236- self ._in_use = True
237- try :
238- columns , rows = await protocol .query_sql (db_id , sql , params )
239- except (DqliteConnectionError , ProtocolError ):
240- self ._invalidate ()
241- raise
242- except OperationalError as e :
243- if e .code in _LEADER_ERROR_CODES :
244- self ._invalidate ()
245- raise
246- except BaseException :
247- self ._invalidate ()
248- raise
249- finally :
250- self ._in_use = False
242+ columns , rows = await self ._run_protocol (lambda p , db : p .query_sql (db , sql , params ))
251243 return [dict (zip (columns , row , strict = True )) for row in rows ]
252244
253245 async def fetchall (self , sql : str , params : Sequence [Any ] | None = None ) -> list [list [Any ]]:
254246 """Execute a query and return results as list of lists."""
255- self ._check_in_use ()
256- protocol , db_id = self ._ensure_connected ()
257- self ._in_use = True
258- try :
259- _ , rows = await protocol .query_sql (db_id , sql , params )
260- except (DqliteConnectionError , ProtocolError ):
261- self ._invalidate ()
262- raise
263- except OperationalError as e :
264- if e .code in _LEADER_ERROR_CODES :
265- self ._invalidate ()
266- raise
267- except BaseException :
268- self ._invalidate ()
269- raise
270- finally :
271- self ._in_use = False
247+ _ , rows = await self ._run_protocol (lambda p , db : p .query_sql (db , sql , params ))
272248 return rows
273249
274250 async def fetchone (
@@ -285,23 +261,7 @@ async def fetchone(
285261
286262 async def fetchval (self , sql : str , params : Sequence [Any ] | None = None ) -> Any :
287263 """Execute a query and return the first column of the first row."""
288- self ._check_in_use ()
289- protocol , db_id = self ._ensure_connected ()
290- self ._in_use = True
291- try :
292- _ , rows = await protocol .query_sql (db_id , sql , params )
293- except (DqliteConnectionError , ProtocolError ):
294- self ._invalidate ()
295- raise
296- except OperationalError as e :
297- if e .code in _LEADER_ERROR_CODES :
298- self ._invalidate ()
299- raise
300- except BaseException :
301- self ._invalidate ()
302- raise
303- finally :
304- self ._in_use = False
264+ _ , rows = await self ._run_protocol (lambda p , db : p .query_sql (db , sql , params ))
305265 if rows and rows [0 ]:
306266 return rows [0 ][0 ]
307267 return None
0 commit comments