Skip to content

Commit bc67d36

Browse files
authored
Merge pull request #36 from epsilla-cloud/dev
fix the issue of limit of pyepsilla
2 parents 31c2c05 + ab4b2fa commit bc67d36

5 files changed

Lines changed: 225 additions & 102 deletions

File tree

examples/hello_epsilla.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pyepsilla import vectordb
1111

1212
# Connect to Epsilla VectorDB
13-
client = vectordb.Client(protocol='http', host='127.0.0.1', port='8888')
13+
client = vectordb.Client(protocol="http", host="127.0.0.1", port="8888")
1414

1515
# You can also use Epsilla Cloud
1616
# client = vectordb.Client(protocol='https', host='demo.epsilla.com', port='443')
@@ -25,12 +25,12 @@
2525

2626
# Create a table with schema in current DB
2727
status_code, response = client.create_table(
28-
table_name="MyTable",
29-
table_fields=[
30-
{"name": "ID", "dataType": "INT", "primaryKey": True},
31-
{"name": "Doc", "dataType": "STRING"},
32-
{"name": "Embedding", "dataType": "VECTOR_FLOAT", "dimensions": 4}
33-
]
28+
table_name="MyTable",
29+
table_fields=[
30+
{"name": "ID", "dataType": "INT", "primaryKey": True},
31+
{"name": "Doc", "dataType": "STRING"},
32+
{"name": "Embedding", "dataType": "VECTOR_FLOAT", "dimensions": 4},
33+
],
3434
)
3535
print(response)
3636

@@ -40,43 +40,45 @@
4040

4141
# Insert new vector records into table
4242
status_code, response = client.insert(
43-
table_name="MyTable",
44-
records=[
45-
{"ID": 1, "Doc": "Berlin", "Embedding": [0.05, 0.61, 0.76, 0.74]},
46-
{"ID": 2, "Doc": "London", "Embedding": [0.19, 0.81, 0.75, 0.11]},
47-
{"ID": 3, "Doc": "Moscow", "Embedding": [0.36, 0.55, 0.47, 0.94]},
48-
{"ID": 4, "Doc": "San Francisco", "Embedding": [0.18, 0.01, 0.85, 0.80]},
49-
{"ID": 5, "Doc": "Shanghai", "Embedding": [0.24, 0.18, 0.22, 0.44]}
50-
]
43+
table_name="MyTable",
44+
records=[
45+
{"ID": 1, "Doc": "Berlin", "Embedding": [0.05, 0.61, 0.76, 0.74]},
46+
{"ID": 2, "Doc": "London", "Embedding": [0.19, 0.81, 0.75, 0.11]},
47+
{"ID": 3, "Doc": "Moscow", "Embedding": [0.36, 0.55, 0.47, 0.94]},
48+
{"ID": 4, "Doc": "San Francisco", "Embedding": [0.18, 0.01, 0.85, 0.80]},
49+
{"ID": 5, "Doc": "Shanghai", "Embedding": [0.24, 0.18, 0.22, 0.44]},
50+
],
5151
)
5252
print(response)
5353

5454
# Query Vectors with specific response field
5555
status_code, response = client.query(
56-
table_name="MyTable",
57-
query_field="Embedding",
58-
query_vector=[0.35, 0.55, 0.47, 0.94],
59-
response_fields = ["Doc"],
60-
limit=2
56+
table_name="MyTable",
57+
query_field="Embedding",
58+
query_vector=[0.35, 0.55, 0.47, 0.94],
59+
response_fields=["Doc"],
60+
limit=2,
6161
)
6262

6363
# Query Vectors without specific response field, then it will return all fields
6464
status_code, response = client.query(
65-
table_name="MyTable",
66-
query_field="Embedding",
67-
query_vector=[0.35, 0.55, 0.47, 0.94],
68-
limit=2
65+
table_name="MyTable",
66+
query_field="Embedding",
67+
query_vector=[0.35, 0.55, 0.47, 0.94],
68+
limit=2,
6969
)
7070
print(response)
7171

72+
# Get Vectors
73+
status_code, response = client.get(table_name="MyTable", limit=2)
74+
print(response)
7275

7376
# status_code, response = client.delete(table_name="MyTable", ids=[3])
74-
status_code, response = client.delete(table_name="MyTable", primary_keys=[3, 4])
77+
status_code, response = client.delete(table_name="MyTable", primary_keys=[3, 4])
7578
# status_code, response = client.delete(table_name="MyTable", filter="Doc <> 'San Francisco'")
7679
print(response)
7780

7881

79-
8082
# Drop table
8183
# status_code, response = client.drop_table("MyTable")
8284
# print(response)

pyepsilla/cloud/client.py

Lines changed: 85 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,52 @@
11
#!/usr/bin/env python
22
# -*- coding:utf-8 -*-
3-
import json, datetime, socket, requests, json, pprint
4-
from typing import Union, Optional
3+
from __future__ import annotations
4+
5+
import datetime
6+
import json
7+
import pprint
8+
import socket
9+
from typing import Optional, Union
10+
11+
import requests
512
import sentry_sdk
13+
614
requests.packages.urllib3.disable_warnings()
715

16+
817
class Client(object):
918
def __init__(self, project_id: str, api_key: str):
1019
self._project_id = project_id
1120
self._apikey = api_key
12-
self._baseurl = "https://dispatch.epsilla.com/api/v2/project/{}".format(self._project_id)
21+
self._baseurl = "https://dispatch.epsilla.com/api/v2/project/{}".format(
22+
self._project_id
23+
)
1324
self._timeout = 10
14-
self._header = {'Content-type': 'application/json', "Connection": "close", 'X-API-Key': api_key}
15-
25+
self._header = {
26+
"Content-type": "application/json",
27+
"Connection": "close",
28+
"X-API-Key": api_key,
29+
}
1630

1731
def validate(self):
18-
res = requests.get(url=self._baseurl, data=None, headers=self._header, verify=False)
32+
res = requests.get(
33+
url=self._baseurl, data=None, headers=self._header, verify=False
34+
)
1935
data = res.json()
2036
res.close()
2137
return data
2238

23-
2439
def get_db_list(self):
2540
db_list = []
2641
req_url = "{}/vectordb/list".format(self._baseurl)
2742
res = requests.get(url=req_url, data=None, headers=self._header, verify=False)
2843
status_code = res.status_code
2944
body = res.json()
3045
if status_code == 200 and body["statusCode"] == 200:
31-
db_list = [ db_id for db_id in res.json()["result"] ]
46+
db_list = [db_id for db_id in res.json()["result"]]
3247
res.close()
3348
return db_list
3449

35-
3650
def get_db_info(self, db_id: str):
3751
req_url = "{}/vectordb/{}".format(self._baseurl, db_id)
3852
res = requests.get(url=req_url, data=None, headers=self._header, verify=False)
@@ -41,7 +55,6 @@ def get_db_info(self, db_id: str):
4155
res.close()
4256
return status_code, body
4357

44-
4558
def vectordb(self, db_id: str):
4659
## validate project_id and api_key
4760
res = self.validate()
@@ -54,33 +67,37 @@ def vectordb(self, db_id: str):
5467
## validate db_id
5568
if not db_id in self.get_db_list():
5669
raise Exception("Invalid db_id")
57-
70+
5871
## fetch db public endpoint
5972
status_code, resp = self.get_db_info(db_id=db_id)
6073
if resp["statusCode"] == 200:
61-
return Vectordb(self._project_id, db_id, self._apikey, resp["result"]["public_endpoint"])
74+
return Vectordb(
75+
self._project_id, db_id, self._apikey, resp["result"]["public_endpoint"]
76+
)
6277
else:
6378
print(resp)
6479
raise Exception("Failed to get db info")
65-
80+
6681
## TODO
67-
def create_db(self, db_id: str, db_type: str, db_name: str, db_description: str = ""):
82+
def create_db(
83+
self, db_id: str, db_type: str, db_name: str, db_description: str = ""
84+
):
6885
pass
6986

7087
def delete_db(self, db_id: str):
7188
pass
7289

7390

74-
7591
class Vectordb(Client):
7692
def __init__(self, project_id: str, db_id: str, api_key: str, public_endpoint: str):
7793
self._project_id = project_id
7894
self._db_id = db_id
7995
self._api_key = api_key
8096
self._public_endpoint = public_endpoint
81-
self._baseurl = "https://{}/api/v2/project/{}/vectordb/{}".format(self._public_endpoint, self._project_id, self._db_id)
82-
self._header = {'Content-type': 'application/json', 'X-API-Key': self._api_key}
83-
97+
self._baseurl = "https://{}/api/v2/project/{}/vectordb/{}".format(
98+
self._public_endpoint, self._project_id, self._db_id
99+
)
100+
self._header = {"Content-type": "application/json", "X-API-Key": self._api_key}
84101

85102
## TODO
86103
## create table
@@ -91,22 +108,31 @@ def create_table(self, table_name: str = "MyTable", table_fields: list[str] = No
91108
def drop_table(self, table_name: str):
92109
pass
93110

94-
95111
## insert data into table
96112
def insert(self, table_name: str, records: list[dict]):
97113
req_url = "{}/data/insert".format(self._baseurl)
98114
req_data = {"table": table_name, "data": records}
99-
res = requests.post(url=req_url, data=json.dumps(req_data), headers=self._header, verify=False)
115+
res = requests.post(
116+
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
117+
)
100118
status_code = res.status_code
101119
body = res.json()
102120
res.close()
103121
return status_code, body
104122

105-
106123
## query data from table
107-
def query(self, table_name: str, query_field: str = None, query_vector: Union[list,dict] = None, response_fields: Optional[list] = None, limit: int = 2, filter: Optional[str] = None, with_distance: Optional[bool] = False):
124+
def query(
125+
self,
126+
table_name: str,
127+
query_field: str = None,
128+
query_vector: Union[list, dict] = None,
129+
response_fields: Optional[list] = None,
130+
limit: int = 2,
131+
filter: Optional[str] = None,
132+
with_distance: Optional[bool] = False,
133+
):
108134
req_url = "{}/data/query".format(self._baseurl)
109-
req_data = { "table": table_name }
135+
req_data = {"table": table_name}
110136
if query_field != None:
111137
req_data["queryField"] = query_field
112138
if query_vector != None:
@@ -120,51 +146,74 @@ def query(self, table_name: str, query_field: str = None, query_vector: Union[li
120146
if with_distance != None:
121147
req_data["withDistance"] = with_distance
122148

123-
res = requests.post(url=req_url, data=json.dumps(req_data), headers=self._header, verify=False)
149+
res = requests.post(
150+
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
151+
)
124152
status_code = res.status_code
125153
body = res.json()
126154
res.close()
127155
return status_code, body
128156

129-
130157
## delete data from table
131-
def delete(self, table_name: str, primary_keys: Optional[list[Union[str,int]]] = None, ids: Optional[list[Union[str,int]]] = None, filter: Optional[str] = None):
158+
def delete(
159+
self,
160+
table_name: str,
161+
primary_keys: Optional[list[Union[str, int]]] = None,
162+
ids: Optional[list[Union[str, int]]] = None,
163+
filter: Optional[str] = None,
164+
):
132165
"""Epsilla supports delete records by primary keys as default for now."""
133166
if filter == None:
134167
if primary_keys == None and ids == None:
135-
raise Exception("[ERROR] Please provide at least one of primary keys(ids) and filter to delete record(s).")
168+
raise Exception(
169+
"[ERROR] Please provide at least one of primary keys(ids) and filter to delete record(s)."
170+
)
136171
if primary_keys == None and ids != None:
137172
primary_keys = ids
138173
if primary_keys != None and ids != None:
139174
try:
140175
sentry_sdk.sdk("Duplicate Keys with both primary keys and ids", "info")
141176
except Exception as e:
142177
pass
143-
print("[WARN] Both primary_keys and ids are prvoided, will use primary keys by default!")
178+
print(
179+
"[WARN] Both primary_keys and ids are prvoided, will use primary keys by default!"
180+
)
144181

145182
req_url = "{}/data/delete".format(self._baseurl)
146-
req_data = { "table": table_name }
183+
req_data = {"table": table_name}
147184
if primary_keys != None:
148185
req_data["primaryKeys"] = primary_keys
149186
if filter != None:
150187
req_data["filter"] = filter
151188

152-
res = requests.post(url=req_url, data=json.dumps(req_data), headers=self._header, verify=False)
189+
res = requests.post(
190+
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
191+
)
153192
status_code = res.status_code
154193
body = res.json()
155194
res.close()
156195
return status_code, body
157196

158-
159197
## get data from table
160-
def get(self, table_name: str, response_fields: Optional[list] = None, primary_keys: Optional[list[Union[str,int]]] = None, ids: Optional[list[Union[str,int]]] = None, filter: Optional[str] = None, skip: Optional[int] = None, limit: Optional[int] = None):
198+
def get(
199+
self,
200+
table_name: str,
201+
response_fields: Optional[list] = None,
202+
primary_keys: Optional[list[Union[str, int]]] = None,
203+
ids: Optional[list[Union[str, int]]] = None,
204+
filter: Optional[str] = None,
205+
skip: Optional[int] = None,
206+
limit: Optional[int] = None,
207+
):
161208
"""Epsilla supports get records by primary keys as default for now."""
162209
if primary_keys != None and ids != None:
163210
try:
164211
sentry_sdk.sdk("Duplicate Keys with both primary_keys and ids", "info")
165212
except Exception as e:
166213
pass
167-
print("[WARN]Both primary_keys and ids are prvoided, will use primary keys by default!")
214+
print(
215+
"[WARN]Both primary_keys and ids are prvoided, will use primary keys by default!"
216+
)
168217
if primary_keys == None and ids != None:
169218
primary_keys = ids
170219

@@ -181,12 +230,10 @@ def get(self, table_name: str, response_fields: Optional[list] = None, primary_k
181230
req_data["limit"] = limit
182231

183232
req_url = "{}/data/get".format(self._baseurl)
184-
res = requests.post(url=req_url, data=json.dumps(req_data), headers=self._header, verify=False)
233+
res = requests.post(
234+
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
235+
)
185236
status_code = res.status_code
186237
body = res.json()
187238
res.close()
188239
return status_code, body
189-
190-
191-
192-

pyepsilla/enterprise/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding:utf-8 -*-
3+
from __future__ import annotations
4+
35
import datetime
46
import json
57
import pprint
@@ -221,7 +223,7 @@ def query(
221223
self,
222224
table_name: str,
223225
query_field: str = None,
224-
query_vector: Union[list,dict] = None,
226+
query_vector: Union[list, dict] = None,
225227
response_fields: Optional[list] = None,
226228
limit: int = 2,
227229
filter: Optional[str] = None,

0 commit comments

Comments
 (0)