Skip to content

Commit b911acf

Browse files
authored
fix(oci): terminate stream on finishReason — OCI does not send [DONE] (#757)
* fix(oci): terminate stream on finishReason instead of waiting for [DONE] OCI Generative AI does not send a `data: [DONE]` SSE marker to signal end-of-stream. It sends a final event with `finishReason` and keeps the connection open, causing chat_stream() to hang indefinitely. Emit closing events (message-end / stream-end) and return from the generator when `finishReason` is detected. The [DONE] path is kept as a fallback for forward compatibility. Fixes #756 * test(oci): assert stream termination and full event lifecycle Strengthen V1 and V2 streaming integration tests to verify streams terminate correctly and produce the expected event sequence: - V1: stream-start → text-generation(s) → stream-end - V2: message-start → content-start → content-delta(s) → content-end → message-end Without these assertions the previous tests would have hung forever on the streaming bug rather than failing with a clear error.
1 parent 2598c9a commit b911acf

2 files changed

Lines changed: 58 additions & 32 deletions

File tree

src/cohere/oci_client.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,36 +1055,45 @@ def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Itera
10551055
final_v1_finish_reason = oci_event.get("finishReason", final_v1_finish_reason)
10561056
yield _emit_v1_event(event)
10571057

1058+
stream_finished = False
1059+
1060+
def _emit_closing_events() -> typing.Iterator[bytes]:
1061+
"""Emit the final closing events for the stream."""
1062+
if is_v2:
1063+
if emitted_start:
1064+
if not emitted_content_end:
1065+
yield _emit_v2_event({"type": "content-end", "index": current_content_index})
1066+
message_end_event: typing.Dict[str, typing.Any] = {
1067+
"type": "message-end",
1068+
"id": generation_id,
1069+
"delta": {"finish_reason": final_finish_reason},
1070+
}
1071+
if final_usage:
1072+
message_end_event["delta"]["usage"] = final_usage
1073+
yield _emit_v2_event(message_end_event)
1074+
else:
1075+
yield _emit_v1_event(
1076+
{
1077+
"event_type": "stream-end",
1078+
"finish_reason": final_v1_finish_reason,
1079+
"response": {
1080+
"text": full_v1_text,
1081+
"generation_id": generation_id,
1082+
"finish_reason": final_v1_finish_reason,
1083+
},
1084+
}
1085+
)
1086+
10581087
def _process_line(line: str) -> typing.Iterator[bytes]:
1088+
nonlocal stream_finished
10591089
if not line.startswith("data: "):
10601090
return
10611091

10621092
data_str = line[6:]
10631093
if data_str.strip() == "[DONE]":
1064-
if is_v2:
1065-
if emitted_start:
1066-
if not emitted_content_end:
1067-
yield _emit_v2_event({"type": "content-end", "index": current_content_index})
1068-
message_end_event: typing.Dict[str, typing.Any] = {
1069-
"type": "message-end",
1070-
"id": generation_id,
1071-
"delta": {"finish_reason": final_finish_reason},
1072-
}
1073-
if final_usage:
1074-
message_end_event["delta"]["usage"] = final_usage
1075-
yield _emit_v2_event(message_end_event)
1076-
else:
1077-
yield _emit_v1_event(
1078-
{
1079-
"event_type": "stream-end",
1080-
"finish_reason": final_v1_finish_reason,
1081-
"response": {
1082-
"text": full_v1_text,
1083-
"generation_id": generation_id,
1084-
"finish_reason": final_v1_finish_reason,
1085-
},
1086-
}
1087-
)
1094+
for event_bytes in _emit_closing_events():
1095+
yield event_bytes
1096+
stream_finished = True
10881097
return
10891098

10901099
try:
@@ -1102,15 +1111,23 @@ def _process_line(line: str) -> typing.Iterator[bytes]:
11021111
except Exception as exc:
11031112
raise RuntimeError(f"OCI stream event transformation failed for endpoint '{endpoint}': {exc}") from exc
11041113

1114+
# OCI may not send [DONE] — treat finishReason as stream termination
1115+
if "finishReason" in oci_event:
1116+
for event_bytes in _emit_closing_events():
1117+
yield event_bytes
1118+
stream_finished = True
1119+
11051120
for chunk in stream:
11061121
buffer += chunk
11071122
while b"\n" in buffer:
11081123
line_bytes, buffer = buffer.split(b"\n", 1)
11091124
line = line_bytes.decode("utf-8").strip()
11101125
for event_bytes in _process_line(line):
11111126
yield event_bytes
1127+
if stream_finished:
1128+
return
11121129

1113-
if buffer.strip():
1130+
if buffer.strip() and not stream_finished:
11141131
line = buffer.decode("utf-8").strip()
11151132
for event_bytes in _process_line(line):
11161133
yield event_bytes

tests/test_oci_client.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_chat(self):
113113
self.assertIn("4", response.text)
114114

115115
def test_chat_stream(self):
116-
"""Test V1 streaming chat."""
116+
"""Test V1 streaming chat terminates and produces correct events."""
117117
events = []
118118
for event in self.client.chat_stream(
119119
model="command-r-08-2024",
@@ -125,6 +125,11 @@ def test_chat_stream(self):
125125
text_events = [e for e in events if hasattr(e, "text") and e.text]
126126
self.assertTrue(len(text_events) > 0)
127127

128+
# Verify stream terminates with correct event lifecycle
129+
event_types = [getattr(e, "event_type", None) for e in events]
130+
self.assertEqual(event_types[0], "stream-start")
131+
self.assertEqual(event_types[-1], "stream-end")
132+
128133

129134
@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")
130135
class TestOciClientV2(unittest.TestCase):
@@ -186,7 +191,7 @@ def test_chat_v2(self):
186191
self.assertIsNotNone(response.message)
187192

188193
def test_chat_stream_v2(self):
189-
"""Test streaming chat with v2 client."""
194+
"""Test V2 streaming chat terminates and produces correct event lifecycle."""
190195
events = []
191196
for event in self.client.chat_stream(
192197
model="command-a-03-2025",
@@ -195,11 +200,16 @@ def test_chat_stream_v2(self):
195200
events.append(event)
196201

197202
self.assertTrue(len(events) > 0)
198-
# Verify we received content-delta events with text
199-
content_delta_events = [e for e in events if hasattr(e, "type") and e.type == "content-delta"]
200-
self.assertTrue(len(content_delta_events) > 0)
201203

202-
# Verify we can extract text from events
204+
# Verify full event lifecycle: message-start → content-start → content-delta(s) → content-end → message-end
205+
event_types = [e.type for e in events]
206+
self.assertEqual(event_types[0], "message-start")
207+
self.assertIn("content-start", event_types)
208+
self.assertIn("content-delta", event_types)
209+
self.assertIn("content-end", event_types)
210+
self.assertEqual(event_types[-1], "message-end")
211+
212+
# Verify we can extract text from content-delta events
203213
full_text = ""
204214
for event in events:
205215
if (
@@ -214,7 +224,6 @@ def test_chat_stream_v2(self):
214224
):
215225
full_text += event.delta.message.content.text
216226

217-
# Should have received some text
218227
self.assertTrue(len(full_text) > 0)
219228

220229
@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set")

0 commit comments

Comments
 (0)