55from contextlib import asynccontextmanager
66from typing import Any
77
8- from dqliteclient .exceptions import ConnectionError , OperationalError , ProtocolError
8+ from dqliteclient .exceptions import DqliteConnectionError , OperationalError , ProtocolError
99from dqliteclient .protocol import DqliteProtocol
1010
1111
@@ -57,9 +57,9 @@ async def connect(self) -> None:
5757 timeout = self ._timeout ,
5858 )
5959 except TimeoutError as e :
60- raise ConnectionError (f"Connection to { self ._address } timed out" ) from e
60+ raise DqliteConnectionError (f"Connection to { self ._address } timed out" ) from e
6161 except OSError as e :
62- raise ConnectionError (f"Failed to connect to { self ._address } : { e } " ) from e
62+ raise DqliteConnectionError (f"Failed to connect to { self ._address } : { e } " ) from e
6363
6464 self ._protocol = DqliteProtocol (reader , writer )
6565
@@ -89,7 +89,7 @@ async def __aexit__(self, *args: Any) -> None:
8989 def _ensure_connected (self ) -> tuple [DqliteProtocol , int ]:
9090 """Ensure we're connected and return protocol and db_id."""
9191 if self ._protocol is None or self ._db_id is None :
92- raise ConnectionError ("Not connected" )
92+ raise DqliteConnectionError ("Not connected" )
9393 return self ._protocol , self ._db_id
9494
9595 def _invalidate (self ) -> None :
@@ -105,7 +105,7 @@ async def execute(self, sql: str, params: list[Any] | None = None) -> tuple[int,
105105 protocol , db_id = self ._ensure_connected ()
106106 try :
107107 return await protocol .exec_sql (db_id , sql , params )
108- except (ConnectionError , ProtocolError ):
108+ except (DqliteConnectionError , ProtocolError ):
109109 self ._invalidate ()
110110 raise
111111
@@ -114,7 +114,7 @@ async def fetch(self, sql: str, params: list[Any] | None = None) -> list[dict[st
114114 protocol , db_id = self ._ensure_connected ()
115115 try :
116116 columns , rows = await protocol .query_sql (db_id , sql , params )
117- except (ConnectionError , ProtocolError ):
117+ except (DqliteConnectionError , ProtocolError ):
118118 self ._invalidate ()
119119 raise
120120 return [dict (zip (columns , row , strict = True )) for row in rows ]
@@ -124,7 +124,7 @@ async def fetchall(self, sql: str, params: list[Any] | None = None) -> list[list
124124 protocol , db_id = self ._ensure_connected ()
125125 try :
126126 _ , rows = await protocol .query_sql (db_id , sql , params )
127- except (ConnectionError , ProtocolError ):
127+ except (DqliteConnectionError , ProtocolError ):
128128 self ._invalidate ()
129129 raise
130130 return rows
@@ -139,7 +139,7 @@ async def fetchval(self, sql: str, params: list[Any] | None = None) -> Any:
139139 protocol , db_id = self ._ensure_connected ()
140140 try :
141141 _ , rows = await protocol .query_sql (db_id , sql , params )
142- except (ConnectionError , ProtocolError ):
142+ except (DqliteConnectionError , ProtocolError ):
143143 self ._invalidate ()
144144 raise
145145 if rows and rows [0 ]:
0 commit comments