Skip to content

Commit 93b1d1f

Browse files
committed
support create/drop table for cloud
1 parent ab4b2fa commit 93b1d1f

2 files changed

Lines changed: 164 additions & 38 deletions

File tree

examples/hello_epsilla_cloud.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,19 +6,30 @@
66
# 2. create a table with schema in db
77
# 3. get the api key with project id, run this program
88

9+
import sys
910

1011
from pyepsilla import cloud
1112

1213
# Connect to Epsilla Cloud
1314
client = cloud.Client(
14-
project_id="32ef3a3f-fcb0-4c4b-98bb-fca01bca0d0a", api_key="epsilla"
15+
project_id="7a68814c-f839-4a67-9ec6-93c027c865e6",
16+
api_key="epsilla-cloud-api-key",
1517
)
1618

1719
# Connect to Vectordb
18-
db = client.vectordb(db_id="df7431d0-806b-4654-8b45-4bdb20038e26")
20+
db = client.vectordb(db_id="6accafb1-476d-43b0-aa64-12ebfbf7442d")
1921

2022

21-
# Create a table with schema on Epsilla Cloud Console
23+
# Create a table with schema
24+
status_code, response = db.create_table(
25+
table_name="MyTable",
26+
table_fields=[
27+
{"name": "ID", "dataType": "INT", "primaryKey": True},
28+
{"name": "Doc", "dataType": "STRING"},
29+
{"name": "Embedding", "dataType": "VECTOR_FLOAT", "dimensions": 4},
30+
],
31+
)
32+
print(status_code, response)
2233

2334

2435
# Insert new vector records into table
@@ -49,3 +60,7 @@
4960
status_code, response = db.delete(table_name="MyTable", primary_keys=[4, 5])
5061
status_code, response = db.delete(table_name="MyTable", filter="Doc <> 'San Francisco'")
5162
print(status_code, response)
63+
64+
# Drop table
65+
status_code, response = db.drop_table(table_name="MyTable")
66+
print(status_code, response)

pyepsilla/cloud/client.py

Lines changed: 146 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
import requests
1212
import sentry_sdk
13+
from pydantic import BaseModel, Field, constr
1314

1415
requests.packages.urllib3.disable_warnings()
1516

@@ -18,7 +19,7 @@ class Client(object):
1819
def __init__(self, project_id: str, api_key: str):
1920
self._project_id = project_id
2021
self._apikey = api_key
21-
self._baseurl = "https://dispatch.epsilla.com/api/v2/project/{}".format(
22+
self._baseurl = "https://dispatch.epsilla.com/api/v3/project/{}".format(
2223
self._project_id
2324
)
2425
self._timeout = 10
@@ -30,7 +31,10 @@ def __init__(self, project_id: str, api_key: str):
3031

3132
def validate(self):
3233
res = requests.get(
33-
url=self._baseurl, data=None, headers=self._header, verify=False
34+
url=self._baseurl + "/vectordb/list",
35+
data=None,
36+
headers=self._header,
37+
verify=False,
3438
)
3539
data = res.json()
3640
res.close()
@@ -56,19 +60,20 @@ def get_db_info(self, db_id: str):
5660
return status_code, body
5761

5862
def vectordb(self, db_id: str):
59-
## validate project_id and api_key
63+
# validate project_id and api_key
6064
res = self.validate()
6165
if res["statusCode"] != 200:
6266
if res["statusCode"] == 404:
6367
raise Exception("Invalid project_id")
6468
if res["statusCode"] == 401:
6569
raise Exception("Invalid api_key")
6670

67-
## validate db_id
68-
if not db_id in self.get_db_list():
71+
# validate db_id
72+
db_list = self.get_db_list()
73+
if db_id not in db_list:
6974
raise Exception("Invalid db_id")
7075

71-
## fetch db public endpoint
76+
# fetch db public endpoint
7277
status_code, resp = self.get_db_info(db_id=db_id)
7378
if resp["statusCode"] == 200:
7479
return Vectordb(
@@ -78,14 +83,86 @@ def vectordb(self, db_id: str):
7883
print(resp)
7984
raise Exception("Failed to get db info")
8085

81-
## TODO
86+
# Create DB
8287
def create_db(
83-
self, db_id: str, db_type: str, db_name: str, db_description: str = ""
88+
self,
89+
db_name: str = Field(pattern=r"^[a-zA-Z-0-9]{4,32}$", strict=True),
90+
db_id: Optional[str] = None,
91+
project_id: Optional[str] = "default",
92+
min_replicas: Optional[int] = 0,
93+
max_replicas: Optional[int] = 1,
94+
sharding_init_number: Optional[int] = 1,
95+
sharding_increase_step: Optional[int] = 2,
96+
sharding_capacity: Optional[int] = 150000,
97+
sharding_increase_threshold: Optional[float] = 0.9,
8498
):
85-
pass
99+
req_url = "{}/vectordb/create".format(self._baseurl)
100+
req_data = {
101+
"db_name": db_name,
102+
"db_uuid": db_id,
103+
"project_id": project_id,
104+
"min_replicas": min_replicas,
105+
"max_replicas": max_replicas,
106+
"sharding_init_number": sharding_init_number,
107+
"sharding_increase_step": sharding_increase_step,
108+
"sharding_capacity": sharding_capacity,
109+
"sharding_increase_threshold": sharding_increase_threshold,
110+
}
111+
resp = requests.post(
112+
url=req_url,
113+
data=json.dumps(req_data),
114+
headers=self._header,
115+
verify=False,
116+
)
117+
status_code = resp.status_code
118+
body = resp.json()
119+
resp.close()
120+
return status_code, body
121+
122+
# Load DB
123+
def load_db(self, db_id: str):
124+
req_url = "{}/vectordb/{}/load".format(self._baseurl, db_id)
125+
req_data = {}
126+
resp = requests.post(
127+
url=req_url,
128+
data=json.dumps(req_data),
129+
headers=self._header,
130+
verify=False,
131+
)
132+
status_code = resp.status_code
133+
body = resp.json()
134+
resp.close()
135+
return status_code, body
136+
137+
# Unload DB
138+
def unload_db(self, db_id: str):
139+
req_url = "{}/vectordb/{}/unload".format(self._baseurl, db_id)
140+
req_data = {}
141+
resp = requests.post(
142+
url=req_url,
143+
data=json.dumps(req_data),
144+
headers=self._header,
145+
verify=False,
146+
)
147+
status_code = resp.status_code
148+
body = resp.json()
149+
resp.close()
150+
return status_code, body
86151

87-
def delete_db(self, db_id: str):
88-
pass
152+
# Delete DB
153+
def drop_db(self, db_id: str):
154+
req_url = "{}/vectordb/{}".format(self._baseurl, db_id)
155+
req_data = {}
156+
resp = requests.delete(
157+
url=req_url,
158+
data=json.dumps(req_data),
159+
headers=self._header,
160+
verify=False,
161+
)
162+
status_code = resp.status_code
163+
body = resp.json()
164+
resp.close()
165+
return status_code, body
89166

90167

91168
class Vectordb(Client):
@@ -94,24 +171,58 @@ def __init__(self, project_id: str, db_id: str, api_key: str, public_endpoint: s
94171
self._db_id = db_id
95172
self._api_key = api_key
96173
self._public_endpoint = public_endpoint
97-
self._baseurl = "https://{}/api/v2/project/{}/vectordb/{}".format(
174+
self._baseurl = "https://{}/api/v3/project/{}/vectordb/{}".format(
98175
self._public_endpoint, self._project_id, self._db_id
99176
)
100177
self._header = {"Content-type": "application/json", "X-API-Key": self._api_key}
101178

102-
## TODO
103-
## create table
104-
def create_table(self, table_name: str = "MyTable", table_fields: list[str] = None):
105-
pass
179+
# List table
180+
def list_tables(self):
181+
if self._db_id is None:
182+
raise Exception("[ERROR] db_id is None!")
183+
req_url = "{}/table/list".format(self._baseurl)
184+
res = requests.get(url=req_url, headers=self._header, verify=False)
185+
status_code = res.status_code
186+
body = res.json()
187+
res.close()
188+
return status_code, body
106189

107-
## drop table
190+
# Create table
191+
def create_table(self, table_name: str, table_fields: list[str] = None):
192+
if self._db_id is None:
193+
raise Exception("[ERROR] db_id is None!")
194+
if table_fields is None:
195+
table_fields = []
196+
req_url = "{}/table/create".format(self._baseurl)
197+
req_data = {"table_name": table_name, "fields": table_fields}
198+
res = requests.post(
199+
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
200+
)
201+
status_code = res.status_code
202+
body = res.json()
203+
res.close()
204+
return status_code, body
205+
206+
# Drop table
108207
def drop_table(self, table_name: str):
109-
pass
208+
if self._db_id is None:
209+
raise Exception("[ERROR] db_id is None!")
210+
req_url = "{}/table/delete?table_name={}".format(self._baseurl, table_name)
211+
req_data = {}
212+
res = requests.delete(
213+
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
214+
)
215+
status_code = res.status_code
216+
body = res.json()
217+
res.close()
218+
return status_code, body
110219

111-
## insert data into table
220+
# Insert data into table
112221
def insert(self, table_name: str, records: list[dict]):
113222
req_url = "{}/data/insert".format(self._baseurl)
114223
req_data = {"table": table_name, "data": records}
224+
print("req_url: ", req_url)
225+
print("req_data: ", req_data)
115226
res = requests.post(
116227
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
117228
)
@@ -120,7 +231,7 @@ def insert(self, table_name: str, records: list[dict]):
120231
res.close()
121232
return status_code, body
122233

123-
## query data from table
234+
# Query data from table
124235
def query(
125236
self,
126237
table_name: str,
@@ -154,7 +265,7 @@ def query(
154265
res.close()
155266
return status_code, body
156267

157-
## delete data from table
268+
# Delete data from table
158269
def delete(
159270
self,
160271
table_name: str,
@@ -163,14 +274,14 @@ def delete(
163274
filter: Optional[str] = None,
164275
):
165276
"""Epsilla supports delete records by primary keys as default for now."""
166-
if filter == None:
167-
if primary_keys == None and ids == None:
277+
if filter is None:
278+
if primary_keys is None and ids is None:
168279
raise Exception(
169280
"[ERROR] Please provide at least one of primary keys(ids) and filter to delete record(s)."
170281
)
171-
if primary_keys == None and ids != None:
282+
if primary_keys is None and ids is not None:
172283
primary_keys = ids
173-
if primary_keys != None and ids != None:
284+
if primary_keys is not None and ids is not None:
174285
try:
175286
sentry_sdk.sdk("Duplicate Keys with both primary keys and ids", "info")
176287
except Exception as e:
@@ -181,9 +292,9 @@ def delete(
181292

182293
req_url = "{}/data/delete".format(self._baseurl)
183294
req_data = {"table": table_name}
184-
if primary_keys != None:
295+
if primary_keys is not None:
185296
req_data["primaryKeys"] = primary_keys
186-
if filter != None:
297+
if filter is not None:
187298
req_data["filter"] = filter
188299

189300
res = requests.post(
@@ -194,7 +305,7 @@ def delete(
194305
res.close()
195306
return status_code, body
196307

197-
## get data from table
308+
# Get data from table
198309
def get(
199310
self,
200311
table_name: str,
@@ -206,27 +317,27 @@ def get(
206317
limit: Optional[int] = None,
207318
):
208319
"""Epsilla supports get records by primary keys as default for now."""
209-
if primary_keys != None and ids != None:
320+
if primary_keys is not None and ids is not None:
210321
try:
211322
sentry_sdk.sdk("Duplicate Keys with both primary_keys and ids", "info")
212323
except Exception as e:
213324
pass
214325
print(
215326
"[WARN]Both primary_keys and ids are prvoided, will use primary keys by default!"
216327
)
217-
if primary_keys == None and ids != None:
328+
if primary_keys is None and ids is not None:
218329
primary_keys = ids
219330

220331
req_data = {"table": table_name}
221-
if response_fields != None:
332+
if response_fields is not None:
222333
req_data["response"] = response_fields
223-
if primary_keys != None:
334+
if primary_keys is not None:
224335
req_data["primaryKeys"] = primary_keys
225-
if filter != None:
336+
if filter is not None:
226337
req_data["filter"] = filter
227-
if skip != None:
338+
if skip is not None:
228339
req_data["skip"] = skip
229-
if limit != None:
340+
if limit is not None:
230341
req_data["limit"] = limit
231342

232343
req_url = "{}/data/get".format(self._baseurl)

0 commit comments

Comments
 (0)