Skip to content

Commit eb0771d

Browse files
fix: detect and raise SSE stream error events instead of silently discarding them (#144) (#192)
Add ClassVar-driven error detection to the base Stream class. Mid-stream error events from providers (e.g. Anthropic overloaded_error, OpenAI server_error) now raise StreamEventError instead of being silently skipped.

- Add _error_type_fields ClassVar and _build_error_from_value helper to Stream
- Add _parse_stream_error base implementation handling type-based and field-based SSE error patterns
- Add StreamEventError exception with error_type, event_data, provider attributes
- Add raise_for_status() to stream_post() for HTTP-level errors
- Provider overrides use ClassVar only (Google GenerateContent) or ClassVar + helper (Google Interactions, OpenResponses)
- Remove dead SSE_EVENT_ERROR constant from Anthropic config
1 parent 3ef3f9d commit eb0771d

9 files changed

Lines changed: 256 additions & 7 deletions

File tree

src/celeste/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
MissingCredentialsError,
2525
ModelNotFoundError,
2626
StreamEmptyError,
27+
StreamEventError,
2728
StreamingNotSupportedError,
2829
StreamNotExhaustedError,
2930
UnsupportedCapabilityError,
@@ -264,6 +265,7 @@ def create_client(
264265
"RefResolvingJsonSchemaGenerator",
265266
"Role",
266267
"StreamEmptyError",
268+
"StreamEventError",
267269
"StreamNotExhaustedError",
268270
"StreamingNotSupportedError",
269271
"StrictJsonSchemaGenerator",

src/celeste/exceptions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Custom exceptions for Celeste."""
22

3+
from typing import Any
4+
35

46
class Error(Exception):
57
"""Base exception for all Celeste errors."""
@@ -170,6 +172,27 @@ def __init__(self) -> None:
170172
super().__init__("Stream completed but no chunks were produced")
171173

172174

175+
class StreamEventError(StreamingError):
    """Error event delivered by the provider while a stream is in progress.

    Attributes:
        error_type: Provider-reported error category, or None when absent.
        event_data: The event that carried the error; an empty dict when
            none was supplied.
        provider: Name of the provider that emitted the event, or None.
    """

    def __init__(
        self,
        message: str,
        *,
        error_type: str | None = None,
        event_data: dict[str, Any] | None = None,
        provider: str | None = None,
    ) -> None:
        """Store the error details and build the human-readable message.

        The message reads ``"<provider> stream error [<error_type>]: <message>"``;
        the provider falls back to a generic ``"Stream error"`` prefix and the
        bracketed type is omitted when not given.
        """
        self.provider = provider
        self.error_type = error_type
        self.event_data = event_data if event_data else {}

        pieces = [f"{provider} stream error"] if provider else ["Stream error"]
        if error_type:
            pieces.append(f" [{error_type}]")
        pieces.append(f": {message}")
        super().__init__("".join(pieces))
194+
195+
173196
class MissingDependencyError(Error):
174197
"""Raised when a required optional dependency is not installed."""
175198

@@ -231,6 +254,7 @@ def __init__(self, parameter: str, model_id: str) -> None:
231254
"ModalityNotFoundError",
232255
"ModelNotFoundError",
233256
"StreamEmptyError",
257+
"StreamEventError",
234258
"StreamNotExhaustedError",
235259
"StreamingNotSupportedError",
236260
"UnsupportedCapabilityError",

src/celeste/http.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ async def stream_post(
185185
headers=headers,
186186
timeout=timeout,
187187
) as event_source:
188+
event_source.response.raise_for_status()
188189
async for sse in event_source.aiter_sse():
189190
try:
190191
yield json.loads(sse.data)

src/celeste/protocols/openresponses/streaming.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""OpenResponses protocol SSE parsing for streaming."""
22

3-
from typing import Any
3+
from typing import Any, ClassVar
44

55
from celeste.io import FinishReason
66

@@ -22,6 +22,14 @@ class OpenResponsesStream:
2222
Modality streams call super() methods which resolve to this via MRO.
2323
"""
2424

25+
_error_type_fields: ClassVar[tuple[str, ...]] = ("code",)
26+
27+
def _parse_stream_error(self, event_data: dict[str, Any]) -> dict[str, Any] | None:
28+
"""Detect Responses API error events (flat shape: code/message at root level)."""
29+
if event_data.get("type") == "error":
30+
return self._build_error_from_value(event_data) # type: ignore[attr-defined, no-any-return]
31+
return None
32+
2533
def _parse_chunk_content(self, event_data: dict[str, Any]) -> str | None:
2634
"""Extract content from SSE event."""
2735
event_type = event_data.get("type")

src/celeste/providers/anthropic/messages/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,3 @@ class VertexAnthropicEndpoint(StrEnum):
4848
SSE_EVENT_CONTENT_BLOCK_STOP = "content_block_stop"
4949
SSE_EVENT_MESSAGE_DELTA = "message_delta"
5050
SSE_EVENT_MESSAGE_STOP = "message_stop"
51-
SSE_EVENT_ERROR = "error"

src/celeste/providers/google/generate_content/streaming.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Google GenerateContent SSE parsing for streaming."""
22

3-
from typing import Any
3+
from typing import Any, ClassVar
44

55
from celeste.io import FinishReason
66

@@ -18,6 +18,8 @@ class GoogleGenerateContentStream:
1818
Modality streams call super() methods which resolve to this via MRO.
1919
"""
2020

21+
_error_type_fields: ClassVar[tuple[str, ...]] = ("status", "code")
22+
2123
def _parse_chunk_content(self, event_data: dict[str, Any]) -> str | None:
2224
"""Extract content from SSE event."""
2325
candidates = event_data.get("candidates", [])

src/celeste/providers/google/interactions/streaming.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Google Interactions SSE parsing for streaming."""
22

3-
from typing import Any
3+
from typing import Any, ClassVar
44

55
from celeste.io import FinishReason
66

@@ -22,6 +22,15 @@ class GoogleInteractionsStream:
2222
Modality streams call super() methods which resolve to this via MRO.
2323
"""
2424

25+
_error_type_fields: ClassVar[tuple[str, ...]] = ("status", "code")
26+
27+
def _parse_stream_error(self, event_data: dict[str, Any]) -> dict[str, Any] | None:
28+
"""Detect Interactions error events (dual event_type/type field)."""
29+
event_type = event_data.get("event_type") or event_data.get("type")
30+
if event_type == "error":
31+
return self._build_error_from_value(event_data.get("error", {}))
32+
return None
33+
2534
def _parse_chunk_content(self, event_data: dict[str, Any]) -> str | None:
2635
"""Extract content from SSE event.
2736

src/celeste/streaming.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from anyio.from_thread import BlockingPortal, start_blocking_portal
1010

11-
from celeste.exceptions import StreamNotExhaustedError
11+
from celeste.exceptions import StreamEventError, StreamNotExhaustedError
1212
from celeste.io import Chunk as ChunkBase
1313
from celeste.io import FinishReason, Output, Usage
1414
from celeste.parameters import Parameters
@@ -31,6 +31,7 @@ class Stream[Out: Output, Params: Parameters, Chunk: ChunkBase](ABC):
3131
_chunk_class: ClassVar[type[ChunkBase]]
3232
_output_class: ClassVar[type[Output]]
3333
_empty_content: ClassVar[Any]
34+
_error_type_fields: ClassVar[tuple[str, ...]] = ("type", "code")
3435

3536
def __init__(
3637
self,
@@ -51,6 +52,45 @@ def __init__(
5152
self._portal: BlockingPortal | None = None
5253
self._portal_cm: AbstractContextManager[BlockingPortal] | None = None
5354

55+
def _build_error_from_value(self, error: Any) -> dict[str, Any]: # noqa: ANN401
56+
"""Extract {type, message} from an error value using _error_type_fields."""
57+
if isinstance(error, dict):
58+
error_type = None
59+
for field in self._error_type_fields:
60+
val = error.get(field)
61+
if val is not None:
62+
error_type = str(val)
63+
break
64+
return {
65+
"type": error_type,
66+
"message": error.get("message", "Unknown error"),
67+
}
68+
# Non-dict error value (e.g., plain string from Cohere)
69+
return {"message": str(error) if error else "Unknown error"}
70+
71+
def _parse_stream_error(self, event_data: dict[str, Any]) -> dict[str, Any] | None:
72+
"""Detect error events in the SSE stream.
73+
74+
Handles two generic SSE error patterns:
75+
1. Type-based: {"type": "error", "error": {"type": "...", "message": "..."}}
76+
2. Field-based: {"error": {"message": "...", "type": "..."}}
77+
78+
Override in provider mixin for non-standard error shapes.
79+
"""
80+
error = None
81+
82+
# Pattern 1: Type-based — event has "type": "error" with nested error
83+
if event_data.get("type") == "error":
84+
error = event_data.get("error", {})
85+
# Pattern 2: Field-based — event has top-level "error" dict
86+
elif isinstance(event_data.get("error"), dict):
87+
error = event_data["error"]
88+
89+
if error is None:
90+
return None
91+
92+
return self._build_error_from_value(error)
93+
5494
def _parse_chunk_content(self, event_data: dict[str, Any]) -> Any | None: # noqa: ANN401
5595
"""Parse content from chunk event. Override in provider mixin."""
5696
return None
@@ -61,6 +101,14 @@ def _wrap_chunk_content(self, raw_content: Any) -> Any: # noqa: ANN401
61101

62102
def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
63103
"""Parse SSE event into Chunk (returns None to filter lifecycle events)."""
104+
error = self._parse_stream_error(event)
105+
if error is not None:
106+
raise StreamEventError(
107+
message=error.get("message", "Unknown stream error"),
108+
error_type=error.get("type"),
109+
event_data=event,
110+
provider=self._stream_metadata.get("provider"),
111+
)
64112
content = self._parse_chunk_content(event)
65113
usage = self._get_chunk_usage(event)
66114
finish_reason = self._get_chunk_finish_reason(event)

tests/unit_tests/test_streaming.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""High-value tests for Stream - focusing on lifecycle, resource cleanup, and state management."""
22

33
from collections.abc import AsyncIterator
4-
from typing import Any, Unpack
4+
from typing import Any, ClassVar, Unpack
55
from unittest.mock import AsyncMock
66

77
import pytest
88
from pydantic import Field
99

10-
from celeste.exceptions import StreamNotExhaustedError
10+
from celeste.exceptions import StreamEventError, StreamNotExhaustedError
1111
from celeste.io import Chunk, FinishReason, Output, Usage
1212
from celeste.parameters import Parameters
1313
from celeste.streaming import Stream
@@ -720,3 +720,159 @@ def _parse_output( # type: ignore[override]
720720
assert output.finish_reason is not None
721721
assert isinstance(output.finish_reason, TypedFinishReason)
722722
assert output.finish_reason.reason == "stop"
723+
724+
725+
class PipelineStream(Stream[ConcreteOutput, Parameters, Chunk]):
    """Minimal stream that leaves the base ``_parse_chunk`` pipeline intact.

    ConcreteStream replaces ``_parse_chunk`` wholesale; this class overrides
    only ``_parse_chunk_content`` and ``_aggregate_content``, so events still
    flow through the base ``_parse_stream_error`` → ``StreamEventError`` path
    that the error-detection tests exercise.
    """

    _chunk_class: ClassVar[type[Chunk]] = Chunk
    _output_class: ClassVar[type[Output]] = ConcreteOutput
    _empty_content: ClassVar[str] = ""

    def _aggregate_content(self, chunks: list[Chunk]) -> str:
        """Concatenate the stringified content of every chunk."""
        pieces = [str(chunk.content) for chunk in chunks]
        return "".join(pieces)

    def _parse_chunk_content(self, event_data: dict[str, Any]) -> str | None:
        """Return the event's ``delta`` value, or None when missing or falsy."""
        delta = event_data.get("delta")
        return delta if delta else None
744+
745+
746+
class TestStreamErrorDetection:
    """Exercise the base ``_parse_stream_error`` → ``StreamEventError`` path of Stream."""

    @staticmethod
    async def _drain(stream: PipelineStream) -> None:
        """Iterate *stream* to completion, discarding every chunk."""
        async for _ in stream:
            pass

    async def test_type_based_error_raises_stream_event_error(self) -> None:
        """Type-based error pattern (Anthropic) must raise StreamEventError."""
        error_event = {
            "type": "error",
            "error": {"type": "overloaded_error", "message": "Server overloaded"},
        }
        stream = PipelineStream(
            _async_iter([error_event]),
            stream_metadata={"provider": "anthropic"},
        )

        with pytest.raises(StreamEventError, match="Server overloaded") as exc_info:
            await self._drain(stream)

        raised = exc_info.value
        assert raised.error_type == "overloaded_error"
        assert raised.provider == "anthropic"
        assert raised.event_data == error_event

    async def test_field_based_error_raises_stream_event_error(self) -> None:
        """Field-based error pattern (ChatCompletions) must raise StreamEventError."""
        error_event = {"error": {"type": "invalid_request", "message": "Bad request"}}
        stream = PipelineStream(
            _async_iter([error_event]),
            stream_metadata={"provider": "openai"},
        )

        with pytest.raises(StreamEventError, match="Bad request") as exc_info:
            await self._drain(stream)

        raised = exc_info.value
        assert raised.error_type == "invalid_request"
        assert raised.provider == "openai"

    async def test_field_based_error_falls_back_to_code_field(self) -> None:
        """Field-based error without 'type' must fall back to 'code' field."""
        error_event = {"error": {"code": "rate_limit_exceeded", "message": "Rate limited"}}
        stream = PipelineStream(_async_iter([error_event]))

        with pytest.raises(StreamEventError) as exc_info:
            await self._drain(stream)

        assert exc_info.value.error_type == "rate_limit_exceeded"

    async def test_type_based_error_with_string_error_value(self) -> None:
        """Type-based error with non-dict error value must use string fallback."""
        error_event = {"type": "error", "error": "Something went wrong"}
        stream = PipelineStream(_async_iter([error_event]))

        with pytest.raises(StreamEventError, match="Something went wrong") as exc_info:
            await self._drain(stream)

        assert exc_info.value.error_type is None

    async def test_error_type_fields_classvar_override(self) -> None:
        """ClassVar override of _error_type_fields must change field lookup order."""

        class GoogleLikeStream(PipelineStream):
            _error_type_fields: ClassVar[tuple[str, ...]] = ("status", "code")

        error_event = {
            "error": {
                "status": "PERMISSION_DENIED",
                "code": 403,
                "message": "Forbidden",
            },
        }
        stream = GoogleLikeStream(_async_iter([error_event]))

        with pytest.raises(StreamEventError) as exc_info:
            await self._drain(stream)

        assert exc_info.value.error_type == "PERMISSION_DENIED"

    async def test_non_error_events_pass_through(self) -> None:
        """Normal events must not trigger error detection."""
        stream = PipelineStream(_async_iter([{"delta": "Hello"}, {"delta": " world"}]))

        received = []
        async for chunk in stream:
            received.append(chunk)

        assert len(received) == 2
        assert stream.output.content == "Hello world"

    async def test_error_after_successful_chunks(self) -> None:
        """Error mid-stream (after successful chunks) must raise StreamEventError."""
        events = [
            {"delta": "Hello"},
            {
                "type": "error",
                "error": {"type": "server_error", "message": "Internal error"},
            },
        ]
        stream = PipelineStream(
            _async_iter(events),
            stream_metadata={"provider": "test"},
        )

        # Collected outside the helper: the chunk yielded before the error
        # must still be observable after the exception is raised.
        received: list[Chunk] = []
        with pytest.raises(StreamEventError, match="Internal error"):
            async for chunk in stream:
                received.append(chunk)

        assert len(received) == 1
        assert received[0].content == "Hello"

    async def test_error_with_no_message_uses_default(self) -> None:
        """Error event without message field must use 'Unknown error' default."""
        stream = PipelineStream(_async_iter([{"error": {"type": "mystery_error"}}]))

        with pytest.raises(StreamEventError, match="Unknown error"):
            await self._drain(stream)

    async def test_error_provides_full_event_data(self) -> None:
        """StreamEventError must include the full original event data."""
        error_event = {
            "type": "error",
            "error": {"type": "api_error", "message": "Fail"},
            "extra": "data",
        }
        stream = PipelineStream(_async_iter([error_event]))

        with pytest.raises(StreamEventError) as exc_info:
            await self._drain(stream)

        assert exc_info.value.event_data == error_event

0 commit comments

Comments
 (0)