77import pprint
88import socket
99from typing import Optional , Union
10- from ..utils .search_engine import SearchEngine
1110
1211import requests
1312import sentry_sdk
1413from pydantic import BaseModel , Field , constr
1514
15+ from ..utils .search_engine import SearchEngine
16+
1617requests .packages .urllib3 .disable_warnings () # type: ignore
1718
1819
@@ -26,7 +27,9 @@ class DbModel(BaseModel):
2627
2728
2829class Client (cloud .Client ):
29- def __init__ (self , base_url : str , project_id : Optional [str ] = "default" , headers : dict = None ):
30+ def __init__ (
31+ self , base_url : str , project_id : Optional [str ] = "default" , headers : dict = None
32+ ):
3033 self ._project_id = project_id
3134 self ._baseurl = f"{ base_url } /api/v3/project/{ project_id } "
3235 self ._timeout = 10
@@ -176,7 +179,12 @@ def list_tables(self):
176179 return status_code , body
177180
178181 # Create table
179- def create_table (self , table_name : str , table_fields : list [dict ] = None , indices : list [dict ] = None ):
182+ def create_table (
183+ self ,
184+ table_name : str ,
185+ table_fields : list [dict ] = None ,
186+ indices : list [dict ] = None ,
187+ ):
180188 if self ._db_id is None :
181189 raise Exception ("[ERROR] db_id is None!" )
182190 if table_fields is None :
@@ -250,32 +258,43 @@ def query(
250258 limit : int = 2 ,
251259 filter : Optional [str ] = None ,
252260 with_distance : Optional [bool ] = False ,
261+ facets : Optional [list [dict ]] = None ,
253262 ):
254263 req_url = "{}/data/query" .format (self ._baseurl )
255264 req_data = {"table" : table_name }
256265 if query_text is not None :
257266 req_data ["query" ] = query_text
258267 if query_index is not None :
259268 req_data ["queryIndex" ] = query_index
260- if query_field != None :
269+ if query_field is not None :
261270 req_data ["queryField" ] = query_field
262- if query_vector != None :
271+ if query_vector is not None :
263272 req_data ["queryVector" ] = query_vector
264- if response_fields != None :
273+ if response_fields is not None :
265274 req_data ["response" ] = response_fields
266- if limit != None :
275+ if limit is not None :
267276 req_data ["limit" ] = limit
268- if filter != None :
277+ if filter is not None :
269278 req_data ["filter" ] = filter
270279 if with_distance is not False :
271280 req_data ["withDistance" ] = with_distance
281+ if facets is not None and len (facets ) > 0 :
282+ aggregate_not_existing = 0
283+ for facet in facets :
284+ if "aggregate" not in facet :
285+ aggregate_not_existing += 1
286+ if aggregate_not_existing > 0 :
287+ raise Exception ("[ERROR] key aggregate is a must in facets!" )
288+ else :
289+ req_data ["facets" ] = facets
272290
273291 res = requests .post (
274292 url = req_url , data = json .dumps (req_data ), headers = self ._header , verify = False
275293 )
276294 status_code = res .status_code
277295 body = res .json ()
278296 res .close ()
297+ del res
279298 return status_code , body
280299
281300 # Delete data from table
0 commit comments