Skip to content

Commit 9620390

Browse files
Search engine API
1 parent c91a9b4 commit 9620390

4 files changed

Lines changed: 162 additions & 1 deletion

File tree

pyepsilla/utils/__init__.py

Whitespace-only changes.

pyepsilla/utils/search_engine.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#!/usr/bin/env python
2+
# -*- coding:utf-8 -*-
3+
from __future__ import annotations
4+
5+
import datetime
6+
import json
7+
import socket
8+
import time
9+
from typing import Optional, Union
10+
11+
class VectorRetriever:
12+
def __init__(
13+
self,
14+
db_client,
15+
table_name: str,
16+
primary_key_field: str,
17+
query_index: str = None,
18+
query_field: str = None,
19+
query_vector: Union[list, dict] = None,
20+
response_fields: list = None,
21+
limit: int = 2,
22+
filter: str = ""
23+
):
24+
self._db_client = db_client
25+
self._table_name = table_name
26+
self._primary_key_field = primary_key_field
27+
self._query_index = query_index
28+
self._query_field = query_field
29+
self._query_vector = query_vector
30+
self._response_fields = response_fields
31+
self._limit = limit
32+
self._filter = filter
33+
34+
def retrieve(self, query: str) -> list[dict]:
35+
# Query vectors from the table
36+
status_code, response = self._db_client.query(
37+
table_name=self._table_name,
38+
query_text=query,
39+
query_index=self._query_index,
40+
query_field=self._query_field,
41+
query_vector=self._query_vector,
42+
response_fields=self._response_fields,
43+
limit=self._limit,
44+
filter=self._filter,
45+
with_distance=True,
46+
)
47+
if status_code != 200:
48+
error_msg = response["message"] if "message" in response else "Unknown error"
49+
raise Exception(f"Failed to retrieve data from table {self._table_name}: {error_msg}")
50+
# Add @id from the table to each record based on the primary_key_field
51+
for record in response["result"]:
52+
# Raise exception if the primary_key_field is not found in the record
53+
if self._primary_key_field not in record:
54+
raise Exception(f"Primary key field {self._primary_key_field} not found in the response from table {self._table_name}")
55+
record["@id"] = record[self._primary_key_field]
56+
return response["result"]
57+
58+
class Reranker:
59+
def rerank(self, candidates: list[list[any]]) -> list[any]:
60+
pass
61+
62+
class RRFReRanker(Reranker):
63+
def __init__(self, weights: list[float] = None, k = 50, limit = None):
64+
self._weights = weights
65+
self._k = k
66+
self._limit = limit
67+
68+
def rerank(self, candidates: list[list[any]]) -> list[any]:
69+
# Use candidate["@distance"] of each candidate to rerank
70+
# Initialize weights if not provided
71+
if not self._weights:
72+
self._weights = [1] * len(candidates)
73+
74+
# Calculate RRF scores for each candidate
75+
rrf_scores = {}
76+
for i, candidate_list in enumerate(candidates):
77+
weight = self._weights[i]
78+
for rank, candidate in enumerate(candidate_list, start=1):
79+
# Calculate RRF score for this candidate in this list
80+
rrf_score = weight / (self._k + rank)
81+
# Aggregate scores if candidate appears in multiple lists
82+
if candidate["@id"] in rrf_scores:
83+
rrf_scores[candidate["@id"]]["score"] += rrf_score
84+
else:
85+
rrf_scores[candidate["@id"]] = {"candidate": candidate, "score": rrf_score}
86+
87+
# Sort candidates based on aggregated RRF score
88+
sorted_candidates = sorted(rrf_scores.values(), key=lambda x: x["score"], reverse=True)
89+
90+
# Apply the limit to the final list if specified
91+
if self._limit is not None:
92+
sorted_candidates = sorted_candidates[:self._limit]
93+
94+
# Return only the candidate information, discarding the scores
95+
return [item["candidate"] for item in sorted_candidates]
96+
97+
class SearchEngine:
98+
def __init__(
99+
self,
100+
db_client,
101+
):
102+
self._db_client = db_client
103+
self._retrievers = []
104+
self._reranker: Reranker = None
105+
106+
def add_retriever(
107+
self,
108+
table_name: str,
109+
primary_key_field: str = "ID",
110+
query_index: str = None,
111+
query_field: str = None,
112+
query_vector: Union[list, dict] = None,
113+
response_fields: list = None,
114+
limit: int = 2,
115+
filter: str = ""
116+
) -> SearchEngine:
117+
self._reranker = None
118+
self._retrievers.append(
119+
VectorRetriever(
120+
db_client=self._db_client,
121+
table_name=table_name,
122+
primary_key_field=primary_key_field,
123+
query_index=query_index,
124+
query_field=query_field,
125+
query_vector=query_vector,
126+
response_fields=response_fields,
127+
limit=limit,
128+
filter=filter
129+
)
130+
)
131+
return self
132+
133+
def set_reranker(self, type: str="rrf", weights: list[float] = None, k = 50, limit = None):
134+
# The length of weights should be equal to the number of retrievers
135+
if weights is not None and len(self._retrievers) != len(weights):
136+
raise Exception("The length of weights should be equal to the number of retrievers")
137+
if type == "rrf":
138+
self._reranker = RRFReRanker(weights=weights, k=k, limit=limit)
139+
return self
140+
141+
def search(self, query: str) -> list[dict]:
142+
# If no retriever is added, return error
143+
if not self._retrievers:
144+
raise Exception("No retriever added to the search engine")
145+
# If more than one retrievers are added, must set a reranker
146+
if len(self._retrievers) > 1 and not self._reranker:
147+
raise Exception("More than one retriever added to the search engine, but no reranker is set")
148+
# Retrieve candidates from each retriever
149+
candidates = []
150+
for retriever in self._retrievers:
151+
candidates.append(retriever.retrieve(query))
152+
153+
# Rerank candidates if reranker is set
154+
if self._reranker:
155+
candidates = self._reranker.rerank(candidates)
156+
157+
return candidates

pyepsilla/vectordb/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import socket
88
import time
99
from typing import Optional, Union
10+
from ..utils.search_engine import SearchEngine
1011

1112
import requests
1213
import sentry_sdk
@@ -354,3 +355,6 @@ def drop_db(self, db_name: str):
354355
body = res.json()
355356
res.close()
356357
return status_code, body
358+
359+
def as_search_engine(self):
360+
return SearchEngine(self)

pyepsilla/vectordb/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.3"
1+
__version__ = "0.3.4"

0 commit comments

Comments
 (0)