|
| 1 | +""" |
| 2 | +This module provides NeedleCollections class for interacting with Needle API's collections endpoint. |
| 3 | +""" |
| 4 | + |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +import requests |
| 8 | + |
| 9 | +from needle.v1.models import ( |
| 10 | + NeedleConfig, |
| 11 | + NeedleBaseClient, |
| 12 | + Collection, |
| 13 | + Error, |
| 14 | + SearchResult, |
| 15 | +) |
| 16 | +from needle.v1.collections.files import NeedleCollectionsFiles |
| 17 | + |
| 18 | + |
| 19 | +class NeedleCollections(NeedleBaseClient): |
| 20 | + """ |
| 21 | + A client for interacting with the Needle API's collections endpoint. |
| 22 | +
|
| 23 | + This class provides methods to create and manage collections within the Needle API. |
| 24 | + It uses a requests session to handle HTTP requests with a default timeout of 120 seconds. |
| 25 | + """ |
| 26 | + |
| 27 | + def __init__(self, config: NeedleConfig, headers: dict): |
| 28 | + super().__init__(config, headers) |
| 29 | + |
| 30 | + self.endpoint = f"{config.url}/api/v1/collections" |
| 31 | + self.search_endpoint = f"{config.search_url}/api/v1/collections" |
| 32 | + |
| 33 | + # requests config |
| 34 | + self.session = requests.Session() |
| 35 | + self.session.headers.update(headers) |
| 36 | + self.session.timeout = 120 |
| 37 | + |
| 38 | + # sub clients |
| 39 | + self.files = NeedleCollectionsFiles(config, headers) |
| 40 | + |
| 41 | + def create(self, name: str, file_ids: Optional[list[str]] = None): |
| 42 | + """ |
| 43 | + Creates a new collection with the specified name and file IDs. |
| 44 | +
|
| 45 | + Args: |
| 46 | + name (str): The name of the collection. |
| 47 | + file_ids (Optiona[list[str]]): A list of file IDs to include in the collection. |
| 48 | +
|
| 49 | + Returns: |
| 50 | + Collection: The created collection object. |
| 51 | +
|
| 52 | + Raises: |
| 53 | + Error: If the API request fails. |
| 54 | + """ |
| 55 | + req_body = {"name": name, "file_ids": file_ids} |
| 56 | + resp = self.session.post( |
| 57 | + f"{self.endpoint}", |
| 58 | + json=req_body, |
| 59 | + ) |
| 60 | + body = resp.json() |
| 61 | + if resp.status_code >= 400: |
| 62 | + error = body.get("error") |
| 63 | + raise Error(**error) |
| 64 | + c = body.get("result") |
| 65 | + return Collection(**c) |
| 66 | + |
| 67 | + def get(self, collection_id: str): |
| 68 | + """ |
| 69 | + Retrieves a collection by its ID. |
| 70 | +
|
| 71 | + Args: |
| 72 | + collection_id (str): The ID of the collection to retrieve. |
| 73 | +
|
| 74 | + Returns: |
| 75 | + Collection: The retrieved collection object. |
| 76 | +
|
| 77 | + Raises: |
| 78 | + Error: If the API request fails. |
| 79 | + """ |
| 80 | + resp = self.session.get(f"{self.endpoint}/{collection_id}") |
| 81 | + body = resp.json() |
| 82 | + if resp.status_code >= 400: |
| 83 | + error = body.get("error") |
| 84 | + raise Error(**error) |
| 85 | + c = body.get("result") |
| 86 | + return Collection(**c) |
| 87 | + |
| 88 | + def list(self): |
| 89 | + """ |
| 90 | + Lists all collections. |
| 91 | +
|
| 92 | + Returns: |
| 93 | + list[Collection]: A list of all collections. |
| 94 | +
|
| 95 | + Raises: |
| 96 | + Error: If the API request fails. |
| 97 | + """ |
| 98 | + resp = self.session.get(self.endpoint) |
| 99 | + body = resp.json() |
| 100 | + if resp.status_code >= 400: |
| 101 | + error = body.get("error") |
| 102 | + raise Error(**error) |
| 103 | + return [Collection(**c) for c in body.get("result")] |
| 104 | + |
| 105 | + def search(self, collection_id: str, text: str): |
| 106 | + """ |
| 107 | + Searches within a collection based on the provided parameters. |
| 108 | +
|
| 109 | + Args: |
| 110 | + params (SearchCollectionRequest): The search parameters. |
| 111 | +
|
| 112 | + Returns: |
| 113 | + list[dict]: The search results. |
| 114 | +
|
| 115 | + Raises: |
| 116 | + Error: If the API request fails. |
| 117 | + """ |
| 118 | + endpoint = f"{self.search_endpoint}/{collection_id}/search" |
| 119 | + req_body = {"text": text} |
| 120 | + resp = self.session.post(endpoint, headers=self.headers, json=req_body) |
| 121 | + body = resp.json() |
| 122 | + if resp.status_code >= 400: |
| 123 | + error = body.get("error") |
| 124 | + raise Error(**error) |
| 125 | + return [SearchResult(**c) for c in body.get("result")] |
0 commit comments