Skip to content

Commit 0a7eac9

Browse files
fern-supportThomas Bakerjasonozuzu-cohere
authored
Fix generate_stream error in ClientV2 (#725)
* add new proxy for the raw_client param * fix type mismatch between raw_client and combined_raw_client * add tests for legacy methods --------- Co-authored-by: Thomas Baker <thomas@buildwithfern.com> Co-authored-by: Jason Ozuzu <jasonozuzu@cohere.com>
1 parent f366233 commit 0a7eac9

2 files changed

Lines changed: 36 additions & 5 deletions

File tree

src/cohere/client_v2.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
1-
from .client import Client, AsyncClient
2-
from .v2.client import V2Client, AsyncV2Client
3-
import typing
4-
from .environment import ClientEnvironment
51
import os
6-
import httpx
2+
import typing
73
from concurrent.futures import ThreadPoolExecutor
84

5+
import httpx
6+
from .client import AsyncClient, Client
7+
from .environment import ClientEnvironment
8+
from .v2.client import AsyncRawV2Client, AsyncV2Client, RawV2Client, V2Client
9+
10+
11+
class _CombinedRawClient:
12+
"""Proxy that combines v1 and v2 raw clients.
13+
14+
V2Client and Client both assign to self._raw_client in __init__,
15+
causing a collision when combined in ClientV2/AsyncClientV2.
16+
This proxy delegates to v2 first, falling back to v1 for
17+
legacy methods like generate_stream.
18+
"""
19+
20+
def __init__(self, v1_raw_client: typing.Any, v2_raw_client: typing.Any):
21+
self._v1 = v1_raw_client
22+
self._v2 = v2_raw_client
23+
24+
def __getattr__(self, name: str) -> typing.Any:
25+
try:
26+
return getattr(self._v2, name)
27+
except AttributeError:
28+
return getattr(self._v1, name)
29+
930

1031
class ClientV2(V2Client, Client): # type: ignore
1132
def __init__(
@@ -32,10 +53,12 @@ def __init__(
3253
thread_pool_executor=thread_pool_executor,
3354
log_warning_experimental_features=log_warning_experimental_features,
3455
)
56+
v1_raw = self._raw_client
3557
V2Client.__init__(
3658
self,
3759
client_wrapper=self._client_wrapper
3860
)
61+
self._raw_client = typing.cast(RawV2Client, _CombinedRawClient(v1_raw, self._raw_client))
3962

4063

4164
class AsyncClientV2(AsyncV2Client, AsyncClient): # type: ignore
@@ -63,7 +86,9 @@ def __init__(
6386
thread_pool_executor=thread_pool_executor,
6487
log_warning_experimental_features=log_warning_experimental_features,
6588
)
89+
v1_raw = self._raw_client
6690
AsyncV2Client.__init__(
6791
self,
6892
client_wrapper=self._client_wrapper
6993
)
94+
self._raw_client = typing.cast(AsyncRawV2Client, _CombinedRawClient(v1_raw, self._raw_client))

tests/test_client_v2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ def test_chat_stream(self) -> None:
3636
self.assertTrue("content-delta" in events)
3737
self.assertTrue("content-end" in events)
3838
self.assertTrue("message-end" in events)
39+
40+
def test_legacy_methods_available(self) -> None:
41+
self.assertTrue(hasattr(co, "generate"))
42+
self.assertTrue(callable(getattr(co, "generate")))
43+
self.assertTrue(hasattr(co, "generate_stream"))
44+
self.assertTrue(callable(getattr(co, "generate_stream")))
3945

4046
@unittest.skip("Skip v2 test for now")
4147
def test_chat_documents(self) -> None:

0 commit comments

Comments
 (0)