Skip to content

Commit 7d10721

Browse files
fix: enrich streaming HTTP errors with provider-specific messages (#196)
Streaming errors now produce the same enriched messages as non-streaming errors (e.g. "Anthropic API error: Invalid API Key" instead of the generic "Client error '401 Unauthorized' for url '...'"). Three-layer fix: - http.py: read response body before raise_for_status() so JSON error details are available downstream - streaming.py: add enrich_stream_errors() wrapper that catches HTTPStatusError and delegates to provider error handler - client.py: wire wrapper into _stream() covering all providers Closes #194
1 parent 10133e6 commit 7d10721

5 files changed

Lines changed: 183 additions & 5 deletions

File tree

src/celeste/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from celeste.mime_types import ApplicationMimeType
1717
from celeste.models import Model
1818
from celeste.parameters import ParameterMapper, Parameters
19-
from celeste.streaming import Stream
19+
from celeste.streaming import Stream, enrich_stream_errors
2020
from celeste.types import RawUsage
2121

2222

@@ -250,6 +250,7 @@ def _stream(
250250
extra_headers=extra_headers,
251251
**parameters,
252252
)
253+
sse_iterator = enrich_stream_errors(sse_iterator, self._handle_error_response)
253254
return stream_class(
254255
sse_iterator,
255256
transform_output=self._transform_output,

src/celeste/http.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ async def stream_post(
185185
headers=headers,
186186
timeout=timeout,
187187
) as event_source:
188-
event_source.response.raise_for_status()
188+
if not event_source.response.is_success:
189+
await event_source.response.aread()
190+
event_source.response.raise_for_status()
189191
async for sse in event_source.aiter_sse():
190192
try:
191193
yield json.loads(sse.data)
@@ -221,7 +223,9 @@ async def stream_post_ndjson(
221223
headers=headers,
222224
timeout=timeout,
223225
) as response:
224-
response.raise_for_status()
226+
if not response.is_success:
227+
await response.aread()
228+
response.raise_for_status()
225229
async for line in response.aiter_lines():
226230
if line:
227231
yield json.loads(line)

src/celeste/streaming.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from types import TracebackType
77
from typing import Any, ClassVar, Self, Unpack
88

9+
import httpx
910
from anyio.from_thread import BlockingPortal, start_blocking_portal
1011

1112
from celeste.exceptions import StreamEventError, StreamNotExhaustedError
@@ -15,6 +16,19 @@
1516
from celeste.types import RawUsage
1617

1718

19+
async def enrich_stream_errors(
20+
iterator: AsyncIterator[dict[str, Any]],
21+
error_handler: Callable[[httpx.Response], None],
22+
) -> AsyncIterator[dict[str, Any]]:
23+
"""Wrap stream iterator to enrich HTTP errors with provider-specific messages."""
24+
try:
25+
async for event in iterator:
26+
yield event
27+
except httpx.HTTPStatusError as e:
28+
error_handler(e.response)
29+
raise # Unreachable — error_handler always raises for error responses
30+
31+
1832
class Stream[Out: Output, Params: Parameters, Chunk: ChunkBase](ABC):
1933
"""Async iterator wrapper providing final Output access after stream exhaustion.
2034
@@ -332,4 +346,4 @@ async def aclose(self) -> None:
332346
await self._sse_iterator.aclose()
333347

334348

335-
__all__ = ["Stream"]
349+
__all__ = ["Stream", "enrich_stream_errors"]

tests/unit_tests/test_http.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,8 +823,83 @@ async def test_stream_post_passes_parameters_correctly(
823823
timeout=timeout,
824824
)
825825

826+
async def test_stream_post_raises_http_error_with_readable_body(
827+
self, mock_httpx_client: AsyncMock
828+
) -> None:
829+
"""stream_post reads response body before raising on HTTP errors."""
830+
# Arrange
831+
http_client = HTTPClient()
832+
mock_response = httpx.Response(
833+
401,
834+
content=b'{"error": {"message": "Invalid API Key"}}',
835+
request=httpx.Request("POST", "https://api.example.com/stream"),
836+
)
837+
838+
mock_source = MagicMock()
839+
mock_source.response = mock_response
840+
mock_source.__aenter__ = AsyncMock(return_value=mock_source)
841+
mock_source.__aexit__ = AsyncMock(return_value=False)
842+
843+
# Act & Assert
844+
with (
845+
patch("celeste.http.httpx.AsyncClient", return_value=mock_httpx_client),
846+
patch("celeste.http.aconnect_sse", return_value=mock_source),
847+
pytest.raises(httpx.HTTPStatusError) as exc_info,
848+
):
849+
async for _ in http_client.stream_post(
850+
url="https://api.example.com/stream",
851+
headers={"Authorization": "Bearer bad-key"},
852+
json_body={"prompt": "test"},
853+
):
854+
pass
855+
856+
# Body should be readable for downstream enrichment
857+
assert exc_info.value.response.json()["error"]["message"] == "Invalid API Key"
858+
859+
async def test_stream_post_ndjson_raises_http_error_with_readable_body(
860+
self, mock_httpx_client: AsyncMock
861+
) -> None:
862+
"""stream_post_ndjson reads response body before raising on HTTP errors."""
863+
# Arrange
864+
http_client = HTTPClient()
865+
error_body = b'{"error": {"message": "Forbidden"}}'
866+
mock_response = httpx.Response(
867+
403,
868+
content=error_body,
869+
request=httpx.Request("POST", "https://api.example.com/stream"),
870+
)
871+
mock_httpx_client.stream = MagicMock(return_value=_async_context(mock_response))
872+
873+
# Act & Assert
874+
with (
875+
patch("celeste.http.httpx.AsyncClient", return_value=mock_httpx_client),
876+
pytest.raises(httpx.HTTPStatusError) as exc_info,
877+
):
878+
async for _ in http_client.stream_post_ndjson(
879+
url="https://api.example.com/stream",
880+
headers={"Authorization": "Bearer bad-key"},
881+
json_body={"prompt": "test"},
882+
):
883+
pass
884+
885+
# Body should be readable for downstream enrichment
886+
assert exc_info.value.response.json()["error"]["message"] == "Forbidden"
887+
826888
@staticmethod
827889
async def _async_iter(items: list) -> AsyncIterator:
828890
"""Helper to create async iterator from list."""
829891
for item in items:
830892
yield item
893+
894+
895+
class _async_context:
896+
"""Async context manager wrapping a response for client.stream() mocking."""
897+
898+
def __init__(self, response: httpx.Response) -> None:
899+
self._response = response
900+
901+
async def __aenter__(self) -> httpx.Response:
902+
return self._response
903+
904+
async def __aexit__(self, *args: object) -> None:
905+
pass

tests/unit_tests/test_streaming.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
from typing import Any, ClassVar, Unpack
55
from unittest.mock import AsyncMock
66

7+
import httpx
78
import pytest
89
from pydantic import Field
910

1011
from celeste.exceptions import StreamEventError, StreamNotExhaustedError
1112
from celeste.io import Chunk, FinishReason, Output, Usage
1213
from celeste.parameters import Parameters
13-
from celeste.streaming import Stream
14+
from celeste.streaming import Stream, enrich_stream_errors
1415

1516

1617
class ConcreteOutput(Output[str]):
@@ -876,3 +877,86 @@ async def test_error_provides_full_event_data(self) -> None:
876877
async for _ in stream:
877878
pass
878879
assert exc_info.value.event_data == event
880+
881+
882+
class TestEnrichStreamErrors:
883+
"""Test enrich_stream_errors wraps streaming HTTP errors with provider messages."""
884+
885+
async def test_enriches_http_error_with_provider_message(self) -> None:
886+
"""HTTP errors from stream iterators are enriched via error_handler."""
887+
888+
async def _failing_stream() -> AsyncIterator[dict[str, Any]]:
889+
response = httpx.Response(
890+
401,
891+
content=b'{"error": {"message": "Invalid API Key"}}',
892+
request=httpx.Request("POST", "https://api.example.com/v1/chat"),
893+
)
894+
raise httpx.HTTPStatusError(
895+
"Client error '401 Unauthorized'",
896+
request=response.request,
897+
response=response,
898+
)
899+
yield # type: ignore[misc] # Make this an async generator
900+
901+
def _handle_error(response: httpx.Response) -> None:
902+
error_msg = response.json()["error"]["message"]
903+
raise httpx.HTTPStatusError(
904+
f"TestProvider API error: {error_msg}",
905+
request=response.request,
906+
response=response,
907+
)
908+
909+
enriched = enrich_stream_errors(_failing_stream(), _handle_error)
910+
911+
with pytest.raises(
912+
httpx.HTTPStatusError, match="TestProvider API error: Invalid API Key"
913+
):
914+
async for _ in enriched:
915+
pass
916+
917+
async def test_passes_through_events_on_success(self) -> None:
918+
"""Successful streams pass through events unmodified."""
919+
920+
async def _ok_stream() -> AsyncIterator[dict[str, Any]]:
921+
yield {"delta": "Hello"}
922+
yield {"delta": " world"}
923+
924+
enriched = enrich_stream_errors(_ok_stream(), lambda r: None)
925+
events = [event async for event in enriched]
926+
927+
assert events == [{"delta": "Hello"}, {"delta": " world"}]
928+
929+
async def test_enriches_error_with_non_json_body(self) -> None:
930+
"""Error handler receives response even when body isn't valid JSON."""
931+
932+
async def _failing_stream() -> AsyncIterator[dict[str, Any]]:
933+
response = httpx.Response(
934+
500,
935+
content=b"Internal Server Error",
936+
request=httpx.Request("POST", "https://api.example.com/v1/chat"),
937+
)
938+
raise httpx.HTTPStatusError(
939+
"Server error '500 Internal Server Error'",
940+
request=response.request,
941+
response=response,
942+
)
943+
yield # type: ignore[misc] # Make this an async generator
944+
945+
def _handle_error(response: httpx.Response) -> None:
946+
try:
947+
error_msg = response.json()["error"]["message"]
948+
except Exception:
949+
error_msg = response.text or f"HTTP {response.status_code}"
950+
raise httpx.HTTPStatusError(
951+
f"TestProvider API error: {error_msg}",
952+
request=response.request,
953+
response=response,
954+
)
955+
956+
enriched = enrich_stream_errors(_failing_stream(), _handle_error)
957+
958+
with pytest.raises(
959+
httpx.HTTPStatusError, match="TestProvider API error: Internal Server Error"
960+
):
961+
async for _ in enriched:
962+
pass

0 commit comments

Comments
 (0)