Skip to content

Commit 20ed386

Browse files
authored
Explicitly create response objects (#1)
1 parent 657362f commit 20ed386

File tree

4 files changed

+189
-126
lines changed

4 files changed

+189
-126
lines changed

needle/v1/collections/__init__.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,15 @@ def create(self, name: str, file_ids: Optional[list[str]] = None):
6262
error = body.get("error")
6363
raise Error(**error)
6464
c = body.get("result")
65-
return Collection(**c)
65+
return Collection(
66+
id=c.get("id"),
67+
name=c.get("name"),
68+
embedding_model=c.get("embedding_model"),
69+
embedding_dimensions=c.get("embedding_dimensions"),
70+
search_queries=c.get("search_queries"),
71+
created_at=c.get("created_at"),
72+
updated_at=c.get("updated_at"),
73+
)
6674

6775
def get(self, collection_id: str):
6876
"""
@@ -83,7 +91,15 @@ def get(self, collection_id: str):
8391
error = body.get("error")
8492
raise Error(**error)
8593
c = body.get("result")
86-
return Collection(**c)
94+
return Collection(
95+
id=c.get("id"),
96+
name=c.get("name"),
97+
embedding_model=c.get("embedding_model"),
98+
embedding_dimensions=c.get("embedding_dimensions"),
99+
search_queries=c.get("search_queries"),
100+
created_at=c.get("created_at"),
101+
updated_at=c.get("updated_at"),
102+
)
87103

88104
def list(self):
89105
"""
@@ -100,7 +116,18 @@ def list(self):
100116
if resp.status_code >= 400:
101117
error = body.get("error")
102118
raise Error(**error)
103-
return [Collection(**c) for c in body.get("result")]
119+
return [
120+
Collection(
121+
id=c.get("id"),
122+
name=c.get("name"),
123+
embedding_model=c.get("embedding_model"),
124+
embedding_dimensions=c.get("embedding_dimensions"),
125+
search_queries=c.get("search_queries"),
126+
created_at=c.get("created_at"),
127+
updated_at=c.get("updated_at"),
128+
)
129+
for c in body.get("result")
130+
]
104131

105132
def search(self, collection_id: str, text: str):
106133
"""
@@ -122,4 +149,10 @@ def search(self, collection_id: str, text: str):
122149
if resp.status_code >= 400:
123150
error = body.get("error")
124151
raise Error(**error)
125-
return [SearchResult(**c) for c in body.get("result")]
152+
return [
153+
SearchResult(
154+
content=r.get("content"),
155+
file_id=r.get("file_id"),
156+
)
157+
for r in body.get("result")
158+
]

needle/v1/collections/files.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,22 @@ def add(self, collection_id: str, files: list[FileToAdd]):
5050
if resp.status_code >= 400:
5151
error = body.get("error")
5252
raise Error(**error)
53-
return [CollectionFile(**cf) for cf in body.get("result")]
53+
return [
54+
CollectionFile(
55+
id=cf.get("id"),
56+
name=cf.get("name"),
57+
type=cf.get("type"),
58+
url=cf.get("url"),
59+
user_id=cf.get("user_id"),
60+
connector_id=cf.get("connector_id"),
61+
size=cf.get("size"),
62+
md5_hash=cf.get("md5_hash"),
63+
created_at=cf.get("created_at"),
64+
updated_at=cf.get("updated_at"),
65+
status=cf.get("status"),
66+
)
67+
for cf in body.get("result")
68+
]
5469

5570
def list(self, collection_id: str):
5671
"""
@@ -71,4 +86,19 @@ def list(self, collection_id: str):
7186
if resp.status_code >= 400:
7287
error = body.get("error")
7388
raise Error(**error)
74-
return [CollectionFile(**cf) for cf in body.get("result")]
89+
return [
90+
CollectionFile(
91+
id=cf.get("id"),
92+
name=cf.get("name"),
93+
type=cf.get("type"),
94+
url=cf.get("url"),
95+
user_id=cf.get("user_id"),
96+
connector_id=cf.get("connector_id"),
97+
size=cf.get("size"),
98+
md5_hash=cf.get("md5_hash"),
99+
created_at=cf.get("created_at"),
100+
updated_at=cf.get("updated_at"),
101+
status=cf.get("status"),
102+
)
103+
for cf in body.get("result")
104+
]

0 commit comments

Comments
 (0)