|
7 | 7 | import socket |
8 | 8 | import time |
9 | 9 | from typing import Optional, Union |
| 10 | +import asyncio |
| 11 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
10 | 12 |
|
11 | 13 | class VectorRetriever: |
12 | 14 | def __init__( |
@@ -138,17 +140,29 @@ def set_reranker(self, type: str="rrf", weights: list[float] = None, k = 50, lim |
138 | 140 | self._reranker = RRFReRanker(weights=weights, k=k, limit=limit) |
139 | 141 | return self |
140 | 142 |
|
141 | | - def search(self, query: str) -> list[dict]: |
| 143 | + async def search(self, query: str) -> list[dict]: |
142 | 144 | # If no retriever is added, return error |
143 | 145 | if not self._retrievers: |
144 | 146 | raise Exception("No retriever added to the search engine") |
145 | 147 | # If more than one retrievers are added, must set a reranker |
146 | 148 | if len(self._retrievers) > 1 and not self._reranker: |
147 | 149 | raise Exception("More than one retriever added to the search engine, but no reranker is set") |
148 | | - # Retrieve candidates from each retriever |
| 150 | + |
| 151 | + # Function to wrap synchronous call in async coroutine |
| 152 | + def run_retriever(retriever): |
| 153 | + return retriever.retrieve(query) |
| 154 | + # Use ThreadPoolExecutor to run retrievers concurrently |
149 | 155 | candidates = [] |
150 | | - for retriever in self._retrievers: |
151 | | - candidates.append(retriever.retrieve(query)) |
| 156 | + with ThreadPoolExecutor(max_workers=len(self._retrievers)) as executor: |
| 157 | + # Schedule the execution of each retriever and immediately return future objects |
| 158 | + future_to_retriever = {executor.submit(run_retriever, retriever): retriever for retriever in self._retrievers} |
| 159 | + |
| 160 | + for future in as_completed(future_to_retriever): |
| 161 | + try: |
| 162 | + result = future.result() |
| 163 | + candidates.append(result) |
| 164 | + except Exception as exc: |
| 165 | + print(f'Retriever generated an exception: {exc}') |
152 | 166 |
|
153 | 167 | # Rerank candidates if reranker is set |
154 | 168 | if self._reranker: |
|
0 commit comments