Skip to content

Commit fc167c1

Browse files
authored
fix(oci): uppercase tool type field for OCI V2 API compatibility (#758)
* fix(oci): uppercase tool type field for OCI V2 API compatibility OCI Generative AI expects tool type as "FUNCTION" (uppercase) but the SDK passes through the Cohere format "function" (lowercase), causing a 400 error. Transform tool types to uppercase like we do for message roles and content types. * test(oci): add integration test for tool use on OCI on-demand The missing integration test allowed the tool type casing bug to ship undetected. This test calls OCI with a tool definition and verifies the response contains tool_calls with the correct function name and arguments. * fix(oci): complete casing audit for OCI field transformations Fix remaining casing issues found during systematic audit: - V1 tools: uppercase type field (same fix as V2) - tool_calls in messages: uppercase type when sending tool results back in multi-turn conversations - Response tool_calls: lowercase type from OCI's "FUNCTION" back to "function" for Cohere SDK compatibility - safety_mode: uppercase defensively (CONTEXTUAL/STRICT/OFF) Integration tests added for each: - test_chat_tool_use_response_type_lowered: verifies tool_call.type is "function" (not "FUNCTION") in responses - test_chat_multi_turn_tool_use_v2: full tool use round-trip (call → result → final response) - test_chat_safety_mode_v2: verifies safety_mode works on OCI * fix(oci): fix embed embedding_types casing and handle embeddingsByType response - embedding_types: OCI expects lowercase (float, int8) not uppercase. The .upper() was breaking all embedding_types requests. - Response: OCI returns "embeddingsByType" (not "embeddings") when embeddingTypes is specified. Handle both response keys. - Unit test updated to expect lowercase. - Integration tests added: embedding_types=["float"] and truncate modes. * fix(oci): guard safety_mode.upper() against None value Address Cursor Bugbot review: safety_mode is Optional, so the SDK can pass None when the user explicitly sets safety_mode=None. 
Guard with a None check before calling .upper() on both V1 and V2 paths.
1 parent b911acf commit fc167c1

2 files changed

Lines changed: 189 additions & 11 deletions

File tree

src/cohere/oci_client.py

Lines changed: 35 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -669,7 +669,8 @@ def transform_request_to_oci(
669669
oci_body["truncate"] = cohere_body["truncate"].upper()
670670

671671
if "embedding_types" in cohere_body:
672-
oci_body["embeddingTypes"] = [et.upper() for et in cohere_body["embedding_types"]]
672+
# OCI expects lowercase embedding types (float, int8, binary, etc.)
673+
oci_body["embeddingTypes"] = [et.lower() for et in cohere_body["embedding_types"]]
673674
if "max_tokens" in cohere_body:
674675
oci_body["maxTokens"] = cohere_body["max_tokens"]
675676
if "output_dimension" in cohere_body:
@@ -728,7 +729,13 @@ def transform_request_to_oci(
728729
oci_msg["content"] = msg.get("content") or []
729730

730731
if "tool_calls" in msg:
731-
oci_msg["toolCalls"] = msg["tool_calls"]
732+
oci_tool_calls = []
733+
for tc in msg["tool_calls"]:
734+
oci_tc = {**tc}
735+
if "type" in oci_tc:
736+
oci_tc["type"] = oci_tc["type"].upper()
737+
oci_tool_calls.append(oci_tc)
738+
oci_msg["toolCalls"] = oci_tool_calls
732739
if "tool_call_id" in msg:
733740
oci_msg["toolCallId"] = msg["tool_call_id"]
734741
if "tool_plan" in msg:
@@ -756,7 +763,13 @@ def transform_request_to_oci(
756763
if "stop_sequences" in cohere_body:
757764
chat_request["stopSequences"] = cohere_body["stop_sequences"]
758765
if "tools" in cohere_body:
759-
chat_request["tools"] = cohere_body["tools"]
766+
oci_tools = []
767+
for tool in cohere_body["tools"]:
768+
oci_tool = {**tool}
769+
if "type" in oci_tool:
770+
oci_tool["type"] = oci_tool["type"].upper()
771+
oci_tools.append(oci_tool)
772+
chat_request["tools"] = oci_tools
760773
if "strict_tools" in cohere_body:
761774
chat_request["strictTools"] = cohere_body["strict_tools"]
762775
if "documents" in cohere_body:
@@ -765,8 +778,8 @@ def transform_request_to_oci(
765778
chat_request["citationOptions"] = cohere_body["citation_options"]
766779
if "response_format" in cohere_body:
767780
chat_request["responseFormat"] = cohere_body["response_format"]
768-
if "safety_mode" in cohere_body:
769-
chat_request["safetyMode"] = cohere_body["safety_mode"]
781+
if "safety_mode" in cohere_body and cohere_body["safety_mode"] is not None:
782+
chat_request["safetyMode"] = cohere_body["safety_mode"].upper()
770783
if "logprobs" in cohere_body:
771784
chat_request["logprobs"] = cohere_body["logprobs"]
772785
if "tool_choice" in cohere_body:
@@ -810,13 +823,19 @@ def transform_request_to_oci(
810823
if "documents" in cohere_body:
811824
chat_request["documents"] = cohere_body["documents"]
812825
if "tools" in cohere_body:
813-
chat_request["tools"] = cohere_body["tools"]
826+
oci_tools = []
827+
for tool in cohere_body["tools"]:
828+
oci_tool = {**tool}
829+
if "type" in oci_tool:
830+
oci_tool["type"] = oci_tool["type"].upper()
831+
oci_tools.append(oci_tool)
832+
chat_request["tools"] = oci_tools
814833
if "tool_results" in cohere_body:
815834
chat_request["toolResults"] = cohere_body["tool_results"]
816835
if "response_format" in cohere_body:
817836
chat_request["responseFormat"] = cohere_body["response_format"]
818-
if "safety_mode" in cohere_body:
819-
chat_request["safetyMode"] = cohere_body["safety_mode"]
837+
if "safety_mode" in cohere_body and cohere_body["safety_mode"] is not None:
838+
chat_request["safetyMode"] = cohere_body["safety_mode"].upper()
820839
if "priority" in cohere_body:
821840
chat_request["priority"] = cohere_body["priority"]
822841

@@ -857,7 +876,8 @@ def transform_oci_response_to_cohere(
857876
Transformed response in Cohere format
858877
"""
859878
if endpoint == "embed":
860-
embeddings_data = oci_response.get("embeddings", {})
879+
# OCI returns "embeddings" by default, or "embeddingsByType" when embeddingTypes is specified
880+
embeddings_data = oci_response.get("embeddingsByType") or oci_response.get("embeddings", {})
861881

862882
if isinstance(embeddings_data, dict):
863883
normalized_embeddings = {str(key).lower(): value for key, value in embeddings_data.items()}
@@ -911,7 +931,12 @@ def transform_oci_response_to_cohere(
911931
message = {**message, "content": transformed_content}
912932

913933
if "toolCalls" in message:
914-
tool_calls = message["toolCalls"]
934+
tool_calls = []
935+
for tc in message["toolCalls"]:
936+
lowered_tc = {**tc}
937+
if "type" in lowered_tc:
938+
lowered_tc["type"] = lowered_tc["type"].lower()
939+
tool_calls.append(lowered_tc)
915940
message = {k: v for k, v in message.items() if k != "toolCalls"}
916941
message["tool_calls"] = tool_calls
917942
if "toolPlan" in message:

tests/test_oci_client.py

Lines changed: 154 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,135 @@ def test_chat_v2(self):
190190
self.assertIsNotNone(response)
191191
self.assertIsNotNone(response.message)
192192

193+
def test_chat_tool_use_v2(self):
194+
"""Test tool use with v2 client on OCI on-demand inference."""
195+
response = self.client.chat(
196+
model="command-a-03-2025",
197+
messages=[{"role": "user", "content": "What's the weather in Toronto?"}],
198+
max_tokens=200,
199+
tools=[{
200+
"type": "function",
201+
"function": {
202+
"name": "get_weather",
203+
"description": "Get current weather for a location",
204+
"parameters": {
205+
"type": "object",
206+
"properties": {
207+
"location": {"type": "string", "description": "City name"}
208+
},
209+
"required": ["location"],
210+
},
211+
},
212+
}],
213+
)
214+
215+
self.assertIsNotNone(response)
216+
self.assertIsNotNone(response.message)
217+
self.assertEqual(response.finish_reason, "TOOL_CALL")
218+
self.assertTrue(len(response.message.tool_calls) > 0)
219+
tool_call = response.message.tool_calls[0]
220+
self.assertEqual(tool_call.function.name, "get_weather")
221+
self.assertIn("Toronto", tool_call.function.arguments)
222+
223+
def test_chat_tool_use_response_type_lowered(self):
224+
"""Test that tool_call type is lowercased in response (OCI returns FUNCTION)."""
225+
response = self.client.chat(
226+
model="command-a-03-2025",
227+
messages=[{"role": "user", "content": "What's the weather in Toronto?"}],
228+
max_tokens=200,
229+
tools=[{
230+
"type": "function",
231+
"function": {
232+
"name": "get_weather",
233+
"description": "Get current weather for a location",
234+
"parameters": {
235+
"type": "object",
236+
"properties": {
237+
"location": {"type": "string", "description": "City name"}
238+
},
239+
"required": ["location"],
240+
},
241+
},
242+
}],
243+
)
244+
245+
self.assertEqual(response.finish_reason, "TOOL_CALL")
246+
tool_call = response.message.tool_calls[0]
247+
# OCI returns "FUNCTION" — SDK must lowercase to "function" for Cohere compat
248+
self.assertEqual(tool_call.type, "function")
249+
250+
def test_chat_multi_turn_tool_use_v2(self):
251+
"""Test multi-turn tool use: send tool result back after tool call."""
252+
# Step 1: Get a tool call
253+
response = self.client.chat(
254+
model="command-a-03-2025",
255+
messages=[{"role": "user", "content": "What's the weather in Toronto?"}],
256+
max_tokens=200,
257+
tools=[{
258+
"type": "function",
259+
"function": {
260+
"name": "get_weather",
261+
"description": "Get current weather for a location",
262+
"parameters": {
263+
"type": "object",
264+
"properties": {
265+
"location": {"type": "string", "description": "City name"}
266+
},
267+
"required": ["location"],
268+
},
269+
},
270+
}],
271+
)
272+
self.assertEqual(response.finish_reason, "TOOL_CALL")
273+
tool_call = response.message.tool_calls[0]
274+
275+
# Step 2: Send tool result back
276+
response2 = self.client.chat(
277+
model="command-a-03-2025",
278+
messages=[
279+
{"role": "user", "content": "What's the weather in Toronto?"},
280+
{
281+
"role": "assistant",
282+
"tool_calls": [{"id": tool_call.id, "type": "function", "function": {"name": "get_weather", "arguments": tool_call.function.arguments}}],
283+
"tool_plan": response.message.tool_plan,
284+
},
285+
{
286+
"role": "tool",
287+
"tool_call_id": tool_call.id,
288+
"content": [{"type": "text", "text": "15°C, sunny"}],
289+
},
290+
],
291+
max_tokens=200,
292+
tools=[{
293+
"type": "function",
294+
"function": {
295+
"name": "get_weather",
296+
"description": "Get current weather for a location",
297+
"parameters": {
298+
"type": "object",
299+
"properties": {
300+
"location": {"type": "string", "description": "City name"}
301+
},
302+
"required": ["location"],
303+
},
304+
},
305+
}],
306+
)
307+
308+
self.assertIsNotNone(response2.message)
309+
# Model should respond with text incorporating the tool result
310+
self.assertTrue(len(response2.message.content) > 0)
311+
312+
def test_chat_safety_mode_v2(self):
313+
"""Test that safety_mode is uppercased for OCI."""
314+
# Cohere SDK enum values are already uppercase, but test lowercase too
315+
response = self.client.chat(
316+
model="command-a-03-2025",
317+
messages=[{"role": "user", "content": "Say hi"}],
318+
safety_mode="STRICT",
319+
)
320+
self.assertIsNotNone(response.message)
321+
193322
def test_chat_stream_v2(self):
194323
"""Test V2 streaming chat terminates and produces correct event lifecycle."""
195324
events = []
@@ -389,6 +518,30 @@ def test_embed_search_query_input_type(self):
389518
self.assertIsNotNone(response.embeddings.float_)
390519
self.assertEqual(len(response.embeddings.float_[0]), 1024)
391520

521+
def test_embed_with_embedding_types(self):
522+
"""Test embed with explicit embedding_types parameter."""
523+
response = self.client.embed(
524+
model="embed-english-v3.0",
525+
texts=["Hello world"],
526+
input_type="search_document",
527+
embedding_types=["float"],
528+
)
529+
self.assertIsNotNone(response.embeddings.float_)
530+
self.assertEqual(len(response.embeddings.float_[0]), 1024)
531+
532+
def test_embed_with_truncate(self):
533+
"""Test embed with truncate parameter."""
534+
long_text = "hello " * 1000
535+
for mode in ["NONE", "START", "END"]:
536+
response = self.client.embed(
537+
model="embed-english-v3.0",
538+
texts=[long_text],
539+
input_type="search_document",
540+
truncate=mode,
541+
)
542+
self.assertIsNotNone(response.embeddings.float_)
543+
self.assertEqual(len(response.embeddings.float_[0]), 1024)
544+
392545
def test_command_r_plus_chat(self):
393546
"""Test command-r-plus-08-2024 via V1 client."""
394547
v1_client = cohere.OciClient(
@@ -652,7 +805,7 @@ def test_transform_embed_request(self):
652805
self.assertEqual(result["inputs"], ["hello", "world"])
653806
self.assertEqual(result["inputType"], "SEARCH_DOCUMENT")
654807
self.assertEqual(result["truncate"], "END")
655-
self.assertEqual(result["embeddingTypes"], ["FLOAT", "INT8"])
808+
self.assertEqual(result["embeddingTypes"], ["float", "int8"])
656809
self.assertEqual(result["compartmentId"], "compartment-123")
657810
self.assertEqual(result["servingMode"]["modelId"], "cohere.embed-english-v3.0")
658811

0 commit comments

Comments (0)