Skip to content

Commit 8b099b4

Browse files
committed
update Facets
1 parent 602a958 commit 8b099b4

2 files changed

Lines changed: 32 additions & 12 deletions

File tree

pyepsilla/enterprise/client.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import pprint
88
import socket
99
from typing import Optional, Union
10-
from ..utils.search_engine import SearchEngine
1110

1211
import requests
1312
import sentry_sdk
1413
from pydantic import BaseModel, Field, constr
1514

15+
from ..utils.search_engine import SearchEngine
16+
1617
requests.packages.urllib3.disable_warnings() # type: ignore
1718

1819

@@ -26,7 +27,9 @@ class DbModel(BaseModel):
2627

2728

2829
class Client(cloud.Client):
29-
def __init__(self, base_url: str, project_id: Optional[str] = "default", headers: dict = None):
30+
def __init__(
31+
self, base_url: str, project_id: Optional[str] = "default", headers: dict = None
32+
):
3033
self._project_id = project_id
3134
self._baseurl = f"{base_url}/api/v3/project/{project_id}"
3235
self._timeout = 10
@@ -176,7 +179,12 @@ def list_tables(self):
176179
return status_code, body
177180

178181
# Create table
179-
def create_table(self, table_name: str, table_fields: list[dict] = None, indices: list[dict] = None):
182+
def create_table(
183+
self,
184+
table_name: str,
185+
table_fields: list[dict] = None,
186+
indices: list[dict] = None,
187+
):
180188
if self._db_id is None:
181189
raise Exception("[ERROR] db_id is None!")
182190
if table_fields is None:
@@ -250,32 +258,43 @@ def query(
250258
limit: int = 2,
251259
filter: Optional[str] = None,
252260
with_distance: Optional[bool] = False,
261+
facets: Optional[list[dict]] = None,
253262
):
254263
req_url = "{}/data/query".format(self._baseurl)
255264
req_data = {"table": table_name}
256265
if query_text is not None:
257266
req_data["query"] = query_text
258267
if query_index is not None:
259268
req_data["queryIndex"] = query_index
260-
if query_field != None:
269+
if query_field is not None:
261270
req_data["queryField"] = query_field
262-
if query_vector != None:
271+
if query_vector is not None:
263272
req_data["queryVector"] = query_vector
264-
if response_fields != None:
273+
if response_fields is not None:
265274
req_data["response"] = response_fields
266-
if limit != None:
275+
if limit is not None:
267276
req_data["limit"] = limit
268-
if filter != None:
277+
if filter is not None:
269278
req_data["filter"] = filter
270279
if with_distance is not False:
271280
req_data["withDistance"] = with_distance
281+
if facets is not None and len(facets) > 0:
282+
aggregate_not_existing = 0
283+
for facet in facets:
284+
if "aggregate" not in facet:
285+
aggregate_not_existing += 1
286+
if aggregate_not_existing > 0:
287+
raise Exception("[ERROR] key aggregate is a must in facets!")
288+
else:
289+
req_data["facets"] = facets
272290

273291
res = requests.post(
274292
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
275293
)
276294
status_code = res.status_code
277295
body = res.json()
278296
res.close()
297+
del res
279298
return status_code, body
280299

281300
# Delete data from table

pyepsilla/vectordb/client.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,11 @@ def query(
250250
query_index: str = None,
251251
query_field: str = None,
252252
query_vector: Union[list, dict] = None,
253-
response_fields: list = None,
253+
response_fields: Optional[list] = None,
254254
limit: int = 2,
255-
filter: str = "",
256-
with_distance: bool = False,
257-
facets: list[dict] = None,
255+
filter: Optional[str] = None,
256+
with_distance: Optional[bool] = False,
257+
facets: Optional[list[dict]] = None,
258258
):
259259
if self._db is None:
260260
raise Exception("[ERROR] Please use_db() first!")
@@ -291,6 +291,7 @@ def query(
291291
status_code = res.status_code
292292
body = res.json()
293293
res.close()
294+
del res
294295
return status_code, body
295296

296297
def get(

0 commit comments

Comments
 (0)