Skip to content

Commit 1b71ead

Browse files
Hybrid search
1 parent bfc7460 commit 1b71ead

3 files changed

Lines changed: 104 additions & 21 deletions

File tree

pyepsilla/cloud/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pprint
88
import socket
99
from typing import Optional, Union
10+
from ..utils.search_engine import SearchEngine
1011

1112
import requests
1213
import sentry_sdk
@@ -311,3 +312,6 @@ def get(
311312
body = res.json()
312313
res.close()
313314
return status_code, body
315+
316+
def as_search_engine(self):
317+
return SearchEngine(self)

pyepsilla/enterprise/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pprint
88
import socket
99
from typing import Optional, Union
10+
from ..utils.search_engine import SearchEngine
1011

1112
import requests
1213
import sentry_sdk
@@ -360,3 +361,6 @@ def get(
360361
body = res.json()
361362
res.close()
362363
return status_code, body
364+
365+
def as_search_engine(self):
366+
return SearchEngine(self)

pyepsilla/utils/search_engine.py

Lines changed: 96 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import socket
88
import time
99
from typing import Optional, Union
10-
import asyncio
11-
from concurrent.futures import ThreadPoolExecutor, as_completed
1210

1311
class VectorRetriever:
1412
def __init__(
@@ -96,6 +94,87 @@ def rerank(self, candidates: list[list[any]]) -> list[any]:
9694
# Return only the candidate information, discarding the scores
9795
return [item["candidate"] for item in sorted_candidates]
9896

97+
class RelativeScoreFusionReranker(Reranker):
98+
def __init__(self, limit: int = None):
99+
self._limit = limit
100+
101+
def normalize_distances(self, candidates: list[dict]) -> list[dict]:
102+
# Extract all distances
103+
distances = [candidate["@distance"] for candidate in candidates]
104+
105+
if len(distances) < 2 or max(distances) == min(distances):
106+
return [{'candidate': candidate, 'score': 1} for candidate in candidates]
107+
108+
min_distance, max_distance = min(distances), max(distances)
109+
110+
# Normalize distances: (distance - min_distance) / (max_distance - min_distance)
111+
normalized_candidates = []
112+
for candidate in candidates:
113+
normalized_score = (candidate["@distance"] - min_distance) / (max_distance - min_distance)
114+
normalized_candidates.append({'candidate': candidate, 'score': 1 - normalized_score})
115+
116+
return normalized_candidates
117+
118+
def rerank(self, candidates: list[list[dict]]) -> list[dict]:
119+
normalized_lists = [self.normalize_distances(candidate_list) for candidate_list in candidates]
120+
121+
# Aggregate normalized scores across lists
122+
aggregated_scores = {}
123+
for candidate_list in normalized_lists:
124+
for item in candidate_list:
125+
candidate_id = item['candidate']['@id']
126+
if candidate_id in aggregated_scores:
127+
aggregated_scores[candidate_id]['score'] += item['score']
128+
else:
129+
aggregated_scores[candidate_id] = item
130+
131+
# Sort candidates based on aggregated score
132+
sorted_candidates = sorted(aggregated_scores.values(), key=lambda x: x['score'], reverse=True)
133+
134+
# Apply the limit to the final list if specified
135+
if self._limit is not None:
136+
sorted_candidates = sorted_candidates[:self._limit]
137+
138+
# Return only the candidate information, discarding the scores
139+
return [item['candidate'] for item in sorted_candidates]
140+
141+
class DistributionBasedScoreFusionReranker(Reranker):
142+
def __init__(self, scale_ranges: list[list[float]] = [], limit: int = None):
143+
self._limit = limit
144+
self._scale_ranges = scale_ranges
145+
146+
def normalize_distances(self, scale: list[float], candidates: list[dict]) -> list[dict]:
147+
# Normalize distances: (distance - min_distance) / (max_distance - min_distance)
148+
normalized_candidates = []
149+
for candidate in candidates:
150+
normalized_score = max(candidate["@distance"] - scale[0], 0) / (scale[1] - scale[0])
151+
normalized_candidates.append({'candidate': candidate, 'score': 1 - min(1, normalized_score)})
152+
153+
return normalized_candidates
154+
155+
def rerank(self, candidates: list[list[dict]]) -> list[dict]:
156+
normalized_lists = [self.normalize_distances(self._scale_ranges[i], candidate_list) for i, candidate_list in enumerate(candidates)]
157+
158+
# Aggregate normalized scores across lists
159+
aggregated_scores = {}
160+
for candidate_list in normalized_lists:
161+
for item in candidate_list:
162+
candidate_id = item['candidate']['@id']
163+
if candidate_id in aggregated_scores:
164+
aggregated_scores[candidate_id]['score'] += item['score']
165+
else:
166+
aggregated_scores[candidate_id] = item
167+
168+
# Sort candidates based on aggregated score
169+
sorted_candidates = sorted(aggregated_scores.values(), key=lambda x: x['score'], reverse=True)
170+
171+
# Apply the limit to the final list if specified
172+
if self._limit is not None:
173+
sorted_candidates = sorted_candidates[:self._limit]
174+
175+
# Return only the candidate information, discarding the scores
176+
return [item['candidate'] for item in sorted_candidates]
177+
99178
class SearchEngine:
100179
def __init__(
101180
self,
@@ -132,37 +211,33 @@ def add_retriever(
132211
)
133212
return self
134213

135-
def set_reranker(self, type: str="rrf", weights: list[float] = None, k = 50, limit = None):
136-
# The length of weights should be equal to the number of retrievers
137-
if weights is not None and len(self._retrievers) != len(weights):
138-
raise Exception("The length of weights should be equal to the number of retrievers")
139-
if type == "rrf":
214+
def set_reranker(self, type: str="rrf", weights: list[float] = None, scale_ranges: list[list[int]] = [], k = 50, limit = None):
215+
if type == "rrf" or type == "reciprocal_rank_fusion":
216+
if weights is not None and len(self._retrievers) != len(weights):
217+
raise Exception("The length of weights should be equal to the number of retrievers")
140218
self._reranker = RRFReRanker(weights=weights, k=k, limit=limit)
219+
elif type == "rsf" or type == "relative_score_fusion":
220+
self._reranker = RelativeScoreFusionReranker(limit=limit)
221+
elif type == "dbsf" or type == "distribution_based_score_fusion":
222+
if len(scale_ranges) != len(self._retrievers):
223+
raise Exception("The length of scale_ranges should be equal to the number of retrievers")
224+
self._reranker = DistributionBasedScoreFusionReranker(scale_ranges, limit=limit)
225+
else:
226+
raise Exception("Invalid reranker type: " + type)
141227
return self
142228

143-
async def search(self, query: str) -> list[dict]:
229+
def search(self, query: str) -> list[dict]:
144230
# If no retriever is added, return error
145231
if not self._retrievers:
146232
raise Exception("No retriever added to the search engine")
147233
# If more than one retrievers are added, must set a reranker
148234
if len(self._retrievers) > 1 and not self._reranker:
149235
raise Exception("More than one retriever added to the search engine, but no reranker is set")
150236

151-
# Function to wrap synchronous call in async coroutine
152-
def run_retriever(retriever):
153-
return retriever.retrieve(query)
154237
# Use ThreadPoolExecutor to run retrievers concurrently
155238
candidates = []
156-
with ThreadPoolExecutor(max_workers=len(self._retrievers)) as executor:
157-
# Schedule the execution of each retriever and immediately return future objects
158-
future_to_retriever = {executor.submit(run_retriever, retriever): retriever for retriever in self._retrievers}
159-
160-
for future in as_completed(future_to_retriever):
161-
try:
162-
result = future.result()
163-
candidates.append(result)
164-
except Exception as exc:
165-
print(f'Retriever generated an exception: {exc}')
239+
for retriever in self._retrievers:
240+
candidates.append(retriever.retrieve(query))
166241

167242
# Rerank candidates if reranker is set
168243
if self._reranker:

0 commit comments

Comments
 (0)