1- import uuid
21from datetime import datetime
32from azure .cosmos .aio import CosmosClient
43from azure .cosmos import exceptions
54
6- class CosmosConversationClient ():
75
8- def __init__ (self , cosmosdb_endpoint : str , credential : any , database_name : str , container_name : str , enable_message_feedback : bool = False ):
6+ class CosmosConversationClient :
7+
8+ def __init__ (
9+ self ,
10+ cosmosdb_endpoint : str ,
11+ credential : any ,
12+ database_name : str ,
13+ container_name : str ,
14+ enable_message_feedback : bool = False ,
15+ ):
916 self .cosmosdb_endpoint = cosmosdb_endpoint
1017 self .credential = credential
1118 self .database_name = database_name
1219 self .container_name = container_name
1320 self .enable_message_feedback = enable_message_feedback
1421 try :
15- self .cosmosdb_client = CosmosClient (self .cosmosdb_endpoint , credential = credential )
22+ self .cosmosdb_client = CosmosClient (
23+ self .cosmosdb_endpoint , credential = credential
24+ )
1625 except exceptions .CosmosHttpResponseError as e :
1726 if e .status_code == 401 :
1827 raise ValueError ("Invalid credentials" ) from e
1928 else :
2029 raise ValueError ("Invalid CosmosDB endpoint" ) from e
2130
2231 try :
23- self .database_client = self .cosmosdb_client .get_database_client (database_name )
32+ self .database_client = self .cosmosdb_client .get_database_client (
33+ database_name
34+ )
2435 except exceptions .CosmosResourceNotFoundError :
2536 raise ValueError ("Invalid CosmosDB database name" )
2637
2738 try :
28- self .container_client = self .database_client .get_container_client (container_name )
39+ self .container_client = self .database_client .get_container_client (
40+ container_name
41+ )
2942 except exceptions .CosmosResourceNotFoundError :
3043 raise ValueError ("Invalid CosmosDB container name" )
3144
32-
3345 async def ensure (self ):
34- if not self .cosmosdb_client or not self .database_client or not self .container_client :
46+ if (
47+ not self .cosmosdb_client
48+ or not self .database_client
49+ or not self .container_client
50+ ):
3551 return False , "CosmosDB client not initialized correctly"
3652 try :
37- database_info = await self .database_client .read ()
38- except :
39- return False , f"CosmosDB database { self .database_name } on account { self .cosmosdb_endpoint } not found"
53+ await self .database_client .read ()
54+ except Exception :
55+ return (
56+ False ,
57+ f"CosmosDB database { self .database_name } on account { self .cosmosdb_endpoint } not found" ,
58+ )
4059
4160 try :
42- container_info = await self .container_client .read ()
43- except :
61+ await self .container_client .read ()
62+ except Exception :
4463 return False , f"CosmosDB container { self .container_name } not found"
4564
4665 return True , "CosmosDB client initialized successfully"
@@ -55,7 +74,7 @@ async def create_conversation(self, user_id, conversation_id, title=""):
5574 "title" : title ,
5675 "conversationId" : conversation_id ,
5776 }
58- ## TODO: add some error handling based on the output of the upsert_item call
77+ # TODO: add some error handling based on the output of the upsert_item call
5978 resp = await self .container_client .upsert_item (conversation )
6079 if resp :
6180 return resp
@@ -70,114 +89,109 @@ async def upsert_conversation(self, conversation):
7089 return False
7190
7291 async def delete_conversation (self , user_id , conversation_id ):
73- conversation = await self .container_client .read_item (item = conversation_id , partition_key = user_id )
92+ conversation = await self .container_client .read_item (
93+ item = conversation_id , partition_key = user_id
94+ )
7495 if conversation :
75- resp = await self .container_client .delete_item (item = conversation_id , partition_key = user_id )
96+ resp = await self .container_client .delete_item (
97+ item = conversation_id , partition_key = user_id
98+ )
7699 return resp
77100 else :
78101 return True
79102
80-
81103 async def delete_messages (self , conversation_id , user_id ):
82- ## get a list of all the messages in the conversation
104+ # get a list of all the messages in the conversation
83105 messages = await self .get_messages (user_id , conversation_id )
84106 response_list = []
85107 if messages :
86108 for message in messages :
87- resp = await self .container_client .delete_item (item = message ['id' ], partition_key = user_id )
109+ resp = await self .container_client .delete_item (
110+ item = message ["id" ], partition_key = user_id
111+ )
88112 response_list .append (resp )
89113 return response_list
90114
91-
92- async def get_conversations (self , user_id , limit , sort_order = 'DESC' , offset = 0 ):
93- parameters = [
94- {
95- 'name' : '@userId' ,
96- 'value' : user_id
97- }
98- ]
115+ async def get_conversations (self , user_id , limit , sort_order = "DESC" , offset = 0 ):
116+ parameters = [{"name" : "@userId" , "value" : user_id }]
99117 query = f"SELECT * FROM c where c.userId = @userId and c.type='conversation' order by c.updatedAt { sort_order } "
100118 if limit is not None :
101119 query += f" offset { offset } limit { limit } "
102120
103121 conversations = []
104- async for item in self .container_client .query_items (query = query , parameters = parameters ):
122+ async for item in self .container_client .query_items (
123+ query = query , parameters = parameters
124+ ):
105125 conversations .append (item )
106126
107127 return conversations
108128
109129 async def get_conversation (self , user_id , conversation_id ):
110130 parameters = [
111- {
112- 'name' : '@conversationId' ,
113- 'value' : conversation_id
114- },
115- {
116- 'name' : '@userId' ,
117- 'value' : user_id
118- }
131+ {"name" : "@conversationId" , "value" : conversation_id },
132+ {"name" : "@userId" , "value" : user_id },
119133 ]
120- query = f "SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
134+ query = "SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
121135 conversations = []
122- async for item in self .container_client .query_items (query = query , parameters = parameters ):
136+ async for item in self .container_client .query_items (
137+ query = query , parameters = parameters
138+ ):
123139 conversations .append (item )
124140
125- ## if no conversations are found, return None
141+ # if no conversations are found, return None
126142 if len (conversations ) == 0 :
127143 return None
128144 else :
129145 return conversations [0 ]
130146
131147 async def create_message (self , uuid , conversation_id , user_id , input_message : dict ):
132148 message = {
133- 'id' : uuid ,
134- ' type' : ' message' ,
135- ' userId' : user_id ,
136- ' createdAt' : datetime .utcnow ().isoformat (),
137- ' updatedAt' : datetime .utcnow ().isoformat (),
138- ' conversationId' : conversation_id ,
139- ' role' : input_message [' role' ],
140- ' content' : input_message [' content' ]
149+ "id" : uuid ,
150+ " type" : " message" ,
151+ " userId" : user_id ,
152+ " createdAt" : datetime .utcnow ().isoformat (),
153+ " updatedAt" : datetime .utcnow ().isoformat (),
154+ " conversationId" : conversation_id ,
155+ " role" : input_message [" role" ],
156+ " content" : input_message [" content" ],
141157 }
142158
143159 if self .enable_message_feedback :
144- message [' feedback' ] = ''
160+ message [" feedback" ] = ""
145161
146162 resp = await self .container_client .upsert_item (message )
147163 if resp :
148- ## update the parent conversations's updatedAt field with the current message's createdAt datetime value
164+ # update the parent conversations's updatedAt field with the current message's createdAt datetime value
149165 conversation = await self .get_conversation (user_id , conversation_id )
150166 if not conversation :
151167 return "Conversation not found"
152- conversation [' updatedAt' ] = message [' createdAt' ]
168+ conversation [" updatedAt" ] = message [" createdAt" ]
153169 await self .upsert_conversation (conversation )
154170 return resp
155171 else :
156172 return False
157173
158174 async def update_message_feedback (self , user_id , message_id , feedback ):
159- message = await self .container_client .read_item (item = message_id , partition_key = user_id )
175+ message = await self .container_client .read_item (
176+ item = message_id , partition_key = user_id
177+ )
160178 if message :
161- message [' feedback' ] = feedback
179+ message [" feedback" ] = feedback
162180 resp = await self .container_client .upsert_item (message )
163181 return resp
164182 else :
165183 return False
166184
167185 async def get_messages (self , user_id , conversation_id ):
168186 parameters = [
169- {
170- 'name' : '@conversationId' ,
171- 'value' : conversation_id
172- },
173- {
174- 'name' : '@userId' ,
175- 'value' : user_id
176- }
187+ {"name" : "@conversationId" , "value" : conversation_id },
188+ {"name" : "@userId" , "value" : user_id },
177189 ]
178- query = f "SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.timestamp ASC"
190+ query = "SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.timestamp ASC"
179191 messages = []
180- async for item in self .container_client .query_items (query = query , parameters = parameters ):
192+ async for item in self .container_client .query_items (
193+ query = query , parameters = parameters
194+ ):
181195 messages .append (item )
182196
183197 return messages
0 commit comments