Commit 1cab6aa

fern-support, claude, and jsklan authored
fix: resolve AWS client SigV4 signing, forced SageMaker dep, and missing embed params (#728)
* fix: resolve AWS client SigV4 signing, forced SageMaker dep, and missing embed params

  - Fix SigV4 host header mismatch: update copied headers dict with correct host after URL rewrite, so AWSRequest signs with the Bedrock/SageMaker host instead of stale api.cohere.com
  - Add mode parameter to cohere_aws.Client to conditionally initialize boto3 clients (bedrock-runtime/bedrock vs sagemaker-runtime/sagemaker), avoiding forced SageMaker dependency for Bedrock users
  - Add output_dimension and embedding_types params to embed() for Embed v4

  Closes #721

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* test: add integration tests for AWS client fixes

  Add skipped integration tests (gated by TEST_AWS) covering:
  - BedrockClientV2 embed with SigV4 signing (validates host header fix)
  - cohere_aws.Client in Bedrock mode (validates mode param fix)
  - embed() with output_dimension and embedding_types (validates v4 params)

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: guard SageMaker-only methods in Bedrock mode

  Address review feedback: In Bedrock mode, `self._sess` was never set, so SageMaker-only methods would throw confusing AttributeErrors. Now:
  - Initialize `_sess=None` and `_endpoint_name=None` in Bedrock mode
  - Add `_require_sagemaker()` guard to connect_to_endpoint, create_endpoint, export_finetune, summarize, and delete_endpoint

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add type annotation for TestCohereAwsBedrockClient.client

  Fixes mypy attr-defined error in CI.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* test: add mocked unit tests for AWS client fixes

  These run in CI without AWS credentials, covering:
  - SigV4 signing uses correct host header after URL rewrite
  - Mode-conditional boto3 client initialization (sagemaker vs bedrock)
  - Default mode is SAGEMAKER for backwards compat
  - embed() accepts, passes, and strips output_dimension/embedding_types

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: handle embedding_types dict response in AWS client embed methods

  When embedding_types is specified, the Cohere API returns embeddings as a dict (e.g. {"float": [[...]], "int8": [[...]]}) instead of a flat list. Both _bedrock_embed and _sagemaker_embed now detect the dict format and return it directly instead of wrapping it in Embeddings, which would silently produce wrong results for len() and iteration.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update sdk version

* fix: correct assertions for dict response in embed integration tests

  When embedding_types is passed, _bedrock_embed returns a raw dict instead of an Embeddings object. Update test assertions to check for dict type and key presence instead of accessing .embeddings attribute.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: jsklan <jsklan.development@gmail.com>
1 parent b36875e commit 1cab6aa

6 files changed

Lines changed: 402 additions & 11 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ dynamic = ["version"]

 [tool.poetry]
 name = "cohere"
-version = "5.20.5"
+version = "5.20.6"
 description = ""
 readme = "README.md"
 authors = []

src/cohere/aws_client.py

Lines changed: 1 addition & 0 deletions
@@ -239,6 +239,7 @@ def _event_hook(request: httpx.Request) -> None:
         )
         request.url = URL(url)
         request.headers["host"] = request.url.host
+        headers["host"] = request.url.host

         if endpoint == "rerank":
             body["api_version"] = get_api_version(version=api_version)
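
The commit message attributes the signing failure to a copied headers dict that kept the old host after the URL rewrite. The sketch below reproduces that failure mode with the standard library only. `toy_sign` is a hypothetical HMAC-based stand-in, not SigV4 or any botocore API, and the AWS host name is an illustrative value.

```python
import hashlib
import hmac


def toy_sign(secret: bytes, headers: dict) -> str:
    """Hypothetical stand-in for a request signer: HMAC over the sorted headers."""
    canonical = "\n".join(f"{k.lower()}:{v}" for k, v in sorted(headers.items()))
    return hmac.new(secret, canonical.encode(), hashlib.sha256).hexdigest()


secret = b"not-a-real-key"
request_headers = {"host": "api.cohere.com", "content-type": "application/json"}

# The hook copies the headers before the URL is rewritten.
headers = dict(request_headers)

# URL rewrite: the live request now targets the AWS host (illustrative name).
new_host = "bedrock-runtime.us-east-1.amazonaws.com"
request_headers["host"] = new_host

# The copy still carries the old host, so a signature computed from it differs
# from the one derived from the headers that are actually sent.
server_signature = toy_sign(secret, request_headers)
signatures_match_before_fix = toy_sign(secret, headers) == server_signature

# The one-line fix from the diff: update the copy as well before signing.
headers["host"] = new_host
signatures_match_after_fix = toy_sign(secret, headers) == server_signature

print(signatures_match_before_fix, signatures_match_after_fix)  # False True
```

Any signer that covers the `host` header behaves the same way, so the dict handed to the signer has to agree with the request that goes out.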

src/cohere/core/client_wrapper.py

Lines changed: 2 additions & 2 deletions
@@ -26,12 +26,12 @@ def get_headers(self) -> typing.Dict[str, str]:
         import platform

         headers: typing.Dict[str, str] = {
-            "User-Agent": "cohere/5.20.5",
+            "User-Agent": "cohere/5.20.6",
             "X-Fern-Language": "Python",
             "X-Fern-Runtime": f"python/{platform.python_version()}",
             "X-Fern-Platform": f"{platform.system().lower()}/{platform.release()}",
             "X-Fern-SDK-Name": "cohere",
-            "X-Fern-SDK-Version": "5.20.5",
+            "X-Fern-SDK-Version": "5.20.6",
             **(self.get_custom_headers() or {}),
         }
         if self._client_name is not None:

src/cohere/manually_maintained/cohere_aws/client.py

Lines changed: 33 additions & 8 deletions
@@ -20,19 +20,29 @@ class Client:
     def __init__(
         self,
         aws_region: typing.Optional[str] = None,
+        mode: Mode = Mode.SAGEMAKER,
     ):
         """
         By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
         `aws configure set region us-west-2` or override it with `region_name` parameter.
         """
-        self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
-        self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
+        self.mode = mode
         if os.environ.get('AWS_DEFAULT_REGION') is None:
             os.environ['AWS_DEFAULT_REGION'] = aws_region
-        self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
-        self.mode = Mode.SAGEMAKER

+        if self.mode == Mode.SAGEMAKER:
+            self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
+            self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
+            self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
+        elif self.mode == Mode.BEDROCK:
+            self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
+            self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
+            self._sess = None
+            self._endpoint_name = None

+    def _require_sagemaker(self) -> None:
+        if self.mode != Mode.SAGEMAKER:
+            raise CohereError("This method is only supported in SageMaker mode.")

     def _does_endpoint_exist(self, endpoint_name: str) -> bool:
         try:
@@ -50,6 +60,7 @@ def connect_to_endpoint(self, endpoint_name: str) -> None:
         Raises:
             CohereError: Connection to the endpoint failed.
         """
+        self._require_sagemaker()
         if not self._does_endpoint_exist(endpoint_name):
             raise CohereError(f"Endpoint {endpoint_name} does not exist.")
         self._endpoint_name = endpoint_name
@@ -137,6 +148,7 @@ def create_endpoint(
                 will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
                 out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
         """
+        self._require_sagemaker()
         # First, check if endpoint already exists
         if self._does_endpoint_exist(endpoint_name):
             if recreate:
@@ -550,11 +562,15 @@ def embed(
         variant: Optional[str] = None,
         input_type: Optional[str] = None,
         model_id: Optional[str] = None,
-    ) -> Embeddings:
+        output_dimension: Optional[int] = None,
+        embedding_types: Optional[List[str]] = None,
+    ) -> Union[Embeddings, Dict[str, List]]:
         json_params = {
             'texts': texts,
             'truncate': truncate,
-            "input_type": input_type
+            "input_type": input_type,
+            "output_dimension": output_dimension,
+            "embedding_types": embedding_types,
         }
         for key, value in list(json_params.items()):
             if value is None:
@@ -591,7 +607,10 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
             # ValidationError, e.g. when variant is bad
             raise CohereError(str(e))

-        return Embeddings(response['embeddings'])
+        embeddings = response['embeddings']
+        if isinstance(embeddings, dict):
+            return embeddings
+        return Embeddings(embeddings)

     def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
         if not model_id:
@@ -612,7 +631,10 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
             # ValidationError, e.g. when variant is bad
             raise CohereError(str(e))

-        return Embeddings(response['embeddings'])
+        embeddings = response['embeddings']
+        if isinstance(embeddings, dict):
+            return embeddings
+        return Embeddings(embeddings)


     def rerank(self,
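
Since `embed()` is now annotated as returning `Union[Embeddings, Dict[str, List]]`, callers may need to handle both shapes. The following is a minimal sketch under stated assumptions: the `Embeddings` class here is a hypothetical list-backed stand-in rather than the SDK's class, and `parse_embed_response` and `vectors_for` are illustrative helper names that do not exist in the library.

```python
from typing import Dict, List, Union


class Embeddings:
    """Hypothetical stand-in for a list-backed embeddings wrapper."""

    def __init__(self, embeddings):
        self.embeddings = embeddings

    def __len__(self):
        return len(self.embeddings)

    def __iter__(self):
        return iter(self.embeddings)


def parse_embed_response(response: dict) -> Union[Embeddings, Dict[str, List]]:
    """Mirror the branching added in the diff: dict payloads are returned as-is."""
    embeddings = response["embeddings"]
    if isinstance(embeddings, dict):
        return embeddings
    return Embeddings(embeddings)


def vectors_for(result, embedding_type: str = "float") -> List:
    """Illustrative caller-side helper that flattens either shape to a list of vectors."""
    if isinstance(result, dict):
        return result[embedding_type]
    return list(result)


flat = parse_embed_response({"embeddings": [[0.1, 0.2], [0.3, 0.4]]})
typed = parse_embed_response(
    {"embeddings": {"float": [[0.1, 0.2], [0.3, 0.4]], "int8": [[1, 2], [3, 4]]}}
)

# Wrapping the dict would report the number of keys rather than the number of
# texts; here both dict keys and texts happen to number two, so compare a
# single-key payload to see the difference.
wrapped_wrong = Embeddings({"float": [[0.1], [0.2], [0.3]]})
print(len(wrapped_wrong))               # 1 (keys), not 3 (texts)
print(len(vectors_for(flat)))           # 2
print(len(vectors_for(typed, "int8")))  # 2
```

The `wrapped_wrong` case shows the silent miscount the commit message describes: `len()` counts embedding types instead of embedded texts.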
@@ -805,6 +827,7 @@ def export_finetune(
                 This should work when one uses the client inside SageMaker. If this errors out,
                 the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.
         """
+        self._require_sagemaker()
         if name == "model":
             raise ValueError("name cannot be 'model'")

@@ -948,6 +971,7 @@ def summarize(
         additional_command: Optional[str] = "",
         variant: Optional[str] = None
     ) -> Summary:
+        self._require_sagemaker()

         if self._endpoint_name is None:
             raise CohereError("No endpoint connected. "
@@ -989,6 +1013,7 @@ def summarize(


     def delete_endpoint(self) -> None:
+        self._require_sagemaker()
         if self._endpoint_name is None:
             raise CohereError("No endpoint connected.")
         try:
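
The guard pattern added across the SageMaker-only methods can be sketched without boto3. In the sketch below, `Mode`, `CohereError`, and `SketchClient` are simplified stand-ins defined locally for illustration; the enum values and the helper `error_message` are assumptions, not the SDK's definitions.

```python
from enum import Enum
from typing import Optional


class Mode(Enum):
    # Stand-in enum; the values are assumptions, not the SDK's.
    SAGEMAKER = "sagemaker"
    BEDROCK = "bedrock"


class CohereError(Exception):
    """Stand-in for the SDK's error type."""


class SketchClient:
    """Simplified client: no boto3, only the mode and the endpoint name."""

    def __init__(self, mode: Mode = Mode.SAGEMAKER):
        self.mode = mode
        self._endpoint_name: Optional[str] = None

    def _require_sagemaker(self) -> None:
        # Fail early with a clear message instead of an AttributeError later.
        if self.mode != Mode.SAGEMAKER:
            raise CohereError("This method is only supported in SageMaker mode.")

    def delete_endpoint(self) -> None:
        self._require_sagemaker()
        if self._endpoint_name is None:
            raise CohereError("No endpoint connected.")


def error_message(client: SketchClient) -> Optional[str]:
    """Illustrative helper: return the error text raised by delete_endpoint, if any."""
    try:
        client.delete_endpoint()
    except CohereError as exc:
        return str(exc)
    return None


print(error_message(SketchClient(mode=Mode.BEDROCK)))  # This method is only supported in SageMaker mode.
print(error_message(SketchClient()))                   # No endpoint connected.
```

Placing the guard first in each method means a Bedrock-mode caller sees the mode error before any endpoint-state check runs, which matches the ordering in the diff.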
