55from contextlib import asynccontextmanager
66from typing import Any
77
8- from dqliteclient .exceptions import ConnectionError
8+ from dqliteclient .exceptions import ConnectionError , ProtocolError
99from dqliteclient .protocol import DqliteProtocol
1010
1111
@@ -92,24 +92,41 @@ def _ensure_connected(self) -> tuple[DqliteProtocol, int]:
9292 raise ConnectionError ("Not connected" )
9393 return self ._protocol , self ._db_id
9494
95+ def _invalidate (self ) -> None :
96+ """Mark the connection as broken after an unrecoverable error."""
97+ self ._protocol = None
98+ self ._db_id = None
99+
95100 async def execute (self , sql : str , params : list [Any ] | None = None ) -> tuple [int , int ]:
96101 """Execute a SQL statement.
97102
98103 Returns (last_insert_id, rows_affected).
99104 """
100105 protocol , db_id = self ._ensure_connected ()
101- return await protocol .exec_sql (db_id , sql , params )
106+ try :
107+ return await protocol .exec_sql (db_id , sql , params )
108+ except (ConnectionError , ProtocolError ):
109+ self ._invalidate ()
110+ raise
102111
103112 async def fetch (self , sql : str , params : list [Any ] | None = None ) -> list [dict [str , Any ]]:
104113 """Execute a query and return results as list of dicts."""
105114 protocol , db_id = self ._ensure_connected ()
106- columns , rows = await protocol .query_sql (db_id , sql , params )
115+ try :
116+ columns , rows = await protocol .query_sql (db_id , sql , params )
117+ except (ConnectionError , ProtocolError ):
118+ self ._invalidate ()
119+ raise
107120 return [dict (zip (columns , row , strict = True )) for row in rows ]
108121
109122 async def fetchall (self , sql : str , params : list [Any ] | None = None ) -> list [list [Any ]]:
110123 """Execute a query and return results as list of lists."""
111124 protocol , db_id = self ._ensure_connected ()
112- _ , rows = await protocol .query_sql (db_id , sql , params )
125+ try :
126+ _ , rows = await protocol .query_sql (db_id , sql , params )
127+ except (ConnectionError , ProtocolError ):
128+ self ._invalidate ()
129+ raise
113130 return rows
114131
115132 async def fetchone (self , sql : str , params : list [Any ] | None = None ) -> dict [str , Any ] | None :
@@ -120,7 +137,11 @@ async def fetchone(self, sql: str, params: list[Any] | None = None) -> dict[str,
120137 async def fetchval (self , sql : str , params : list [Any ] | None = None ) -> Any :
121138 """Execute a query and return the first column of the first row."""
122139 protocol , db_id = self ._ensure_connected ()
123- _ , rows = await protocol .query_sql (db_id , sql , params )
140+ try :
141+ _ , rows = await protocol .query_sql (db_id , sql , params )
142+ except (ConnectionError , ProtocolError ):
143+ self ._invalidate ()
144+ raise
124145 if rows and rows [0 ]:
125146 return rows [0 ][0 ]
126147 return None
0 commit comments