77import pprint
88import socket
99from typing import Optional , Union
10- from ..utils .search_engine import SearchEngine
1110
1211import requests
1312import sentry_sdk
1413from pydantic import BaseModel , Field , constr
1514
15+ from ..utils .search_engine import SearchEngine
16+
1617requests .packages .urllib3 .disable_warnings ()
1718
1819
@@ -33,53 +34,57 @@ def __init__(self, project_id: str, api_key: str, headers: dict = None):
3334 self ._header .update (headers )
3435
3536 def validate (self ):
36- res = requests .get (
37+ resp = requests .get (
3738 url = self ._baseurl + "/vectordb/list" ,
3839 data = None ,
3940 headers = self ._header ,
4041 verify = False ,
4142 )
42- data = res .json ()
43- res .close ()
43+ data = resp .json ()
44+ resp .close ()
45+ del resp
4446 return data
4547
4648 def get_db_list (self ):
4749 db_list = []
4850 req_url = "{}/vectordb/list" .format (self ._baseurl )
49- res = requests .get (url = req_url , data = None , headers = self ._header , verify = False )
50- status_code = res .status_code
51- body = res .json ()
51+ resp = requests .get (url = req_url , data = None , headers = self ._header , verify = False )
52+ status_code = resp .status_code
53+ body = resp .json ()
5254 if status_code == 200 and body ["statusCode" ] == 200 :
53- db_list = [db_id for db_id in res .json ()["result" ]]
54- res .close ()
55+ db_list = [db_id for db_id in body ["result" ]]
56+ resp .close ()
57+ del resp
5558 return db_list
5659
5760 def get_db_info (self , db_id : str ):
5861 req_url = "{}/vectordb/{}" .format (self ._baseurl , db_id )
59- res = requests .get (url = req_url , data = None , headers = self ._header , verify = False )
60- status_code = res .status_code
61- body = res .json ()
62- res .close ()
62+ resp = requests .get (url = req_url , data = None , headers = self ._header , verify = False )
63+ status_code = resp .status_code
64+ body = resp .json ()
65+ resp .close ()
66+ del resp
6367 return status_code , body
6468
6569 def get_db_statistics (self , db_id : str ):
6670 req_url = "{}/vectordb/{}/statistics" .format (self ._baseurl , db_id )
6771 req_data = None
68- res = requests .get (
72+ resp = requests .get (
6973 url = req_url , data = json .dumps (req_data ), headers = self ._header , verify = False
7074 )
71- status_code = res .status_code
72- body = res .json ()
73- res .close ()
75+ status_code = resp .status_code
76+ body = resp .json ()
77+ resp .close ()
78+ del resp
7479 return status_code , body
7580
7681 def vectordb (self , db_id : str ):
7782 # validate project_id and api_key
78- res = self .validate ()
79- if res ["statusCode" ] != 200 :
80- if res ["statusCode" ] == 404 :
83+ resp = self .validate ()
84+ if resp ["statusCode" ] != 200 :
85+ if resp ["statusCode" ] == 404 :
8186 raise Exception ("Invalid project_id" )
82- if res ["statusCode" ] == 401 :
87+ if resp ["statusCode" ] == 401 :
8388 raise Exception ("Invalid api_key" )
8489
8590 # validate db_id
@@ -95,6 +100,7 @@ def vectordb(self, db_id: str):
95100 )
96101 else :
97102 print (resp )
103+ del resp
98104 raise Exception ("Failed to get db info" )
99105
100106
@@ -123,10 +129,11 @@ def list_tables(self):
123129 if self ._db_id is None :
124130 raise Exception ("[ERROR] db_id is None!" )
125131 req_url = "{}/table/list" .format (self ._baseurl )
126- res = requests .get (url = req_url , headers = self ._header , verify = False )
127- status_code = res .status_code
128- body = res .json ()
129- res .close ()
132+ resp = requests .get (url = req_url , headers = self ._header , verify = False )
133+ status_code = resp .status_code
134+ body = resp .json ()
135+ resp .close ()
136+ del resp
130137 return status_code , body
131138
132139 # Create table
@@ -144,12 +151,13 @@ def create_table(
144151 req_data = {"name" : table_name , "fields" : table_fields }
145152 if indices is not None :
146153 req_data ["indices" ] = indices
147- res = requests .post (
154+ resp = requests .post (
148155 url = req_url , data = json .dumps (req_data ), headers = self ._header , verify = False
149156 )
150- status_code = res .status_code
151- body = res .json ()
152- res .close ()
157+ status_code = resp .status_code
158+ body = resp .json ()
159+ resp .close ()
160+ del resp
153161 return status_code , body
154162
155163 # Drop table
@@ -158,35 +166,38 @@ def drop_table(self, table_name: str):
158166 raise Exception ("[ERROR] db_id is None!" )
159167 req_url = "{}/table/delete?table_name={}" .format (self ._baseurl , table_name )
160168 req_data = {}
161- res = requests .delete (
169+ resp = requests .delete (
162170 url = req_url , data = json .dumps (req_data ), headers = self ._header , verify = False
163171 )
164- status_code = res .status_code
165- body = res .json ()
166- res .close ()
172+ status_code = resp .status_code
173+ body = resp .json ()
174+ resp .close ()
175+ del resp
167176 return status_code , body
168177
169178 # Insert data into table
170179 def insert (self , table_name : str , records : list [dict ]):
171180 req_url = "{}/data/insert" .format (self ._baseurl )
172181 req_data = {"table" : table_name , "data" : records }
173- res = requests .post (
182+ resp = requests .post (
174183 url = req_url , data = json .dumps (req_data ), headers = self ._header , verify = False
175184 )
176- status_code = res .status_code
177- body = res .json ()
178- res .close ()
185+ status_code = resp .status_code
186+ body = resp .json ()
187+ resp .close ()
188+ del resp
179189 return status_code , body
180190
181191 def upsert (self , table_name : str , records : list [dict ]):
182192 req_url = "{}/data/insert" .format (self ._baseurl )
183193 req_data = {"table" : table_name , "data" : records , "upsert" : True }
184- res = requests .post (
194+ resp = requests .post (
185195 url = req_url , data = json .dumps (req_data ), headers = self ._header , verify = False
186196 )
187- status_code = res .status_code
188- body = res .json ()
189- res .close ()
197+ status_code = resp .status_code
198+ body = resp .json ()
199+ resp .close ()
200+ del resp
190201 return status_code , body
191202
192203 # Query data from table
@@ -201,32 +212,46 @@ def query(
201212 limit : int = 2 ,
202213 filter : Optional [str ] = None ,
203214 with_distance : Optional [bool ] = False ,
215+ facets : Optional [list [dict ]] = None ,
204216 ):
205217 req_url = "{}/data/query" .format (self ._baseurl )
206- req_data = {"table" : table_name }
218+ req_data = {"table" : table_name , "limit" : limit }
219+
220+ if response_fields is None :
221+ response_fields = []
222+
207223 if query_text is not None :
208224 req_data ["query" ] = query_text
209225 if query_index is not None :
210226 req_data ["queryIndex" ] = query_index
211- if query_field != None :
227+ if query_field is not None :
212228 req_data ["queryField" ] = query_field
213- if query_vector != None :
229+ if query_vector is not None :
214230 req_data ["queryVector" ] = query_vector
215- if response_fields != None :
231+ if response_fields is not None :
216232 req_data ["response" ] = response_fields
217- if limit != None :
218- req_data ["limit" ] = limit
219- if filter != None :
233+ if filter is not None :
220234 req_data ["filter" ] = filter
221- if with_distance != None :
235+ if with_distance is not False :
222236 req_data ["withDistance" ] = with_distance
223237
224- res = requests .post (
238+ if facets is not None and len (facets ) > 0 :
239+ aggregate_not_existing = 0
240+ for facet in facets :
241+ if "aggregate" not in facet :
242+ aggregate_not_existing += 1
243+ if aggregate_not_existing > 0 :
244+ raise Exception ("[ERROR] key aggregate is a must in facets!" )
245+ else :
246+ req_data ["facets" ] = facets
247+
248+ resp = requests .post (
225249 url = req_url , data = json .dumps (req_data ), headers = self ._header , verify = False
226250 )
227- status_code = res .status_code
228- body = res .json ()
229- res .close ()
251+ status_code = resp .status_code
252+ body = resp .json ()
253+ resp .close ()
254+ del resp
230255 return status_code , body
231256
232257 # Delete data from table
@@ -261,12 +286,13 @@ def delete(
261286 if filter is not None :
262287 req_data ["filter" ] = filter
263288
264- res = requests .post (
289+ resp = requests .post (
265290 url = req_url , data = json .dumps (req_data ), headers = self ._header , verify = False
266291 )
267- status_code = res .status_code
268- body = res .json ()
269- res .close ()
292+ status_code = resp .status_code
293+ body = resp .json ()
294+ resp .close ()
295+ del resp
270296 return status_code , body
271297
272298 # Get data from table
@@ -279,6 +305,7 @@ def get(
279305 filter : Optional [str ] = None ,
280306 skip : Optional [int ] = None ,
281307 limit : Optional [int ] = None ,
308+ facets : Optional [list [dict ]] = None ,
282309 ):
283310 """Epsilla supports get records by primary keys as default for now."""
284311 if primary_keys is not None and ids is not None :
@@ -293,6 +320,7 @@ def get(
293320 primary_keys = ids
294321
295322 req_data = {"table" : table_name }
323+
296324 if response_fields is not None :
297325 req_data ["response" ] = response_fields
298326 if primary_keys is not None :
@@ -304,13 +332,24 @@ def get(
304332 if limit is not None :
305333 req_data ["limit" ] = limit
306334
335+ if facets is not None and len (facets ) > 0 :
336+ aggregate_not_existing = 0
337+ for facet in facets :
338+ if "aggregate" not in facet :
339+ aggregate_not_existing += 1
340+ if aggregate_not_existing > 0 :
341+ raise Exception ("[ERROR] key aggregate is a must in facets!" )
342+ else :
343+ req_data ["facets" ] = facets
344+
307345 req_url = "{}/data/get" .format (self ._baseurl )
308- res = requests .post (
346+ resp = requests .post (
309347 url = req_url , data = json .dumps (req_data ), headers = self ._header , verify = False
310348 )
311- status_code = res .status_code
312- body = res .json ()
313- res .close ()
349+ status_code = resp .status_code
350+ body = resp .json ()
351+ resp .close ()
352+ del resp
314353 return status_code , body
315354
316355 def as_search_engine (self ):
0 commit comments