Skip to content

Commit f99ea69

Browse files
committed
fix custom tests
1 parent c2d74f8 commit f99ea69

4 files changed

Lines changed: 25 additions & 27 deletions

File tree

tests/test_async_client.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,6 @@ async def test_moved_fn(self) -> None:
7676
await self.co.list_connectors("dummy", dummy="dummy") # type: ignore
7777

7878

79-
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
80-
async def test_generate(self) -> None:
81-
response = await self.co.generate(
82-
prompt='Please explain to me how LLMs work',
83-
)
84-
print(response)
85-
8679
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
8780
async def test_embed(self) -> None:
8881
response = await self.co.embed(

tests/test_bedrock_client.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,20 @@ def _setup_boto3_env():
2424
@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set")
2525
class TestClient(unittest.TestCase):
2626
platform: str = "bedrock"
27-
client: cohere.AwsClient = cohere.BedrockClient(
28-
aws_access_key=aws_access_key,
29-
aws_secret_key=aws_secret_key,
30-
aws_session_token=aws_session_token,
31-
aws_region=aws_region,
32-
)
3327
models: typing.Dict[str, str] = {
3428
"chat_model": "cohere.command-r-plus-v1:0",
3529
"embed_model": "cohere.embed-multilingual-v3",
3630
"generate_model": "cohere.command-text-v14",
3731
}
3832

33+
def setUp(self) -> None:
34+
self.client = cohere.BedrockClient(
35+
aws_access_key=aws_access_key,
36+
aws_secret_key=aws_secret_key,
37+
aws_session_token=aws_session_token,
38+
aws_region=aws_region,
39+
)
40+
3941
def test_rerank(self) -> None:
4042
if self.platform != "sagemaker":
4143
self.skipTest("Only sagemaker supports rerank")
@@ -130,12 +132,13 @@ class TestBedrockClientV2(unittest.TestCase):
130132
since the request would fail with a signature mismatch otherwise.
131133
"""
132134

133-
client: cohere.ClientV2 = cohere.BedrockClientV2(
134-
aws_access_key=aws_access_key,
135-
aws_secret_key=aws_secret_key,
136-
aws_session_token=aws_session_token,
137-
aws_region=aws_region,
138-
)
135+
def setUp(self) -> None:
136+
self.client = cohere.BedrockClientV2(
137+
aws_access_key=aws_access_key,
138+
aws_secret_key=aws_secret_key,
139+
aws_session_token=aws_session_token,
140+
aws_region=aws_region,
141+
)
139142

140143
def test_embed(self) -> None:
141144
response = self.client.embed(

tests/test_client.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,6 @@ def test_moved_fn(self) -> None:
7272
with self.assertRaises(ValueError):
7373
co.list_connectors("dummy", dummy="dummy") # type: ignore
7474

75-
@unittest.skipIf(os.getenv("CO_API_URL") is not None, "Doesn't work in staging.")
76-
def test_generate(self) -> None:
77-
response = co.generate(
78-
prompt='Please explain to me how LLMs work',
79-
)
80-
print(response)
81-
8275
def test_embed(self) -> None:
8376
response = co.embed(
8477
texts=['hello', 'goodbye'],

tests/test_client_init.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,21 @@
55
import cohere
66
from cohere import ToolMessage, UserMessage, AssistantMessage
77

8+
try:
9+
import boto3
10+
HAS_BOTO3 = True
11+
except ImportError:
12+
HAS_BOTO3 = False
13+
814
class TestClientInit(unittest.TestCase):
9-
def test_inits(self) -> None:
15+
@unittest.skipUnless(HAS_BOTO3, "boto3 not installed")
16+
def test_aws_inits(self) -> None:
1017
cohere.BedrockClient()
1118
cohere.BedrockClientV2()
1219
cohere.SagemakerClient()
1320
cohere.SagemakerClientV2()
21+
22+
def test_inits(self) -> None:
1423
cohere.Client(api_key="n/a")
1524
cohere.ClientV2(api_key="n/a")
1625

0 commit comments

Comments
 (0)