|
# 3. wget http://ann-benchmarks.com/gist-960-euclidean.hdf5
# 4. python3 gist-960-euclidean.py

import datetime
import os
from urllib.parse import urlparse

import h5py
from pyepsilla import vectordb

# Connect to Epsilla vector database
client = vectordb.Client(host="127.0.0.1", port="8888")
client.load_db(
    db_name="benchmark", db_path="/tmp/epsilla", vector_scale=1000000, wal_enabled=False
)  # pay attention to change db_path to persistent volume for production environment
client.use_db(db_name="benchmark")

# Check gist-960-euclidean dataset hdf5 file to download or not
dataset_download_url = "http://ann-benchmarks.com/gist-960-euclidean.hdf5"
dataset_filename = os.path.basename(urlparse(dataset_download_url).path)
if not os.path.isfile(dataset_filename):
    os.system("wget --no-check-certificate {}".format(dataset_download_url))

# Read gist-960-euclidean data from hdf5
f = h5py.File("gist-960-euclidean.hdf5", "r")
print(list(f.keys()))
training_data = f["train"]
size = training_data.size
records_num, dimensions = training_data.shape

# Create table for gist-960-euclidean
id_field = {"name": "id", "dataType": "INT", "primaryKey": True}
vec_field = {"name": "vector", "dataType": "VECTOR_FLOAT", "dimensions": dimensions}
fields = [id_field, vec_field]
status_code, response = client.create_table(table_name="benchmark", table_fields=fields)

# Insert the first 10000 records into the table as a warm-up batch
records_data = [{"id": i, "vector": training_data[i].tolist()} for i in range(10000)]
client.insert(table_name="benchmark", records=records_data)

# Insert all data into table, in batches of 50000 records.
# The +10000 padding guarantees the final (partial) batch boundary is included.
batch_bounds = [i for i in range(0, records_num + 10000, 50000)]
print("Begin to insert all gist data into table ...")
for batch_idx in range(len(batch_bounds) - 1):
    print("-" * 20)
    start = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    # Clamp the upper bound so the last batch never indexes past the dataset end.
    lo = batch_bounds[batch_idx]
    hi = min(batch_bounds[batch_idx + 1], records_num)
    print(lo, hi)
    records_data = [
        {"id": i, "vector": training_data[i].tolist()} for i in range(lo, hi)
    ]
    client.insert(table_name="benchmark", records=records_data)
    end = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    print("START:", start, "\nEND :", end)


# Delete some data by ids
# client.delete(table_name="benchmark", ids=[300033, 600066])
client.delete(table_name="benchmark", ids=[9999])


# Rebuild ann graph, it will wait until rebuild is finished, wait time is depended on the amount of dataset
client.rebuild()

# Query vectors
query_field = "vector"
query_vector = training_data[40000].tolist()
response_fields = ["id"]
limit = 2

status_code, response = client.query(
    table_name="benchmark",
    query_field=query_field,
    query_vector=query_vector,
    response_fields=response_fields,
    limit=limit,
    with_distance=True,
)
print("Response:", response)


# Get
status_code, body = client.get(table_name="benchmark")
print("Status Code:", status_code)
print("Size of result gotten", len(body["result"]))
|
0 commit comments