|
4 | 4 | from typing import Any, ClassVar, Unpack |
5 | 5 | from unittest.mock import AsyncMock |
6 | 6 |
|
| 7 | +import httpx |
7 | 8 | import pytest |
8 | 9 | from pydantic import Field |
9 | 10 |
|
10 | 11 | from celeste.exceptions import StreamEventError, StreamNotExhaustedError |
11 | 12 | from celeste.io import Chunk, FinishReason, Output, Usage |
12 | 13 | from celeste.parameters import Parameters |
13 | | -from celeste.streaming import Stream |
| 14 | +from celeste.streaming import Stream, enrich_stream_errors |
14 | 15 |
|
15 | 16 |
|
16 | 17 | class ConcreteOutput(Output[str]): |
@@ -876,3 +877,86 @@ async def test_error_provides_full_event_data(self) -> None: |
876 | 877 | async for _ in stream: |
877 | 878 | pass |
878 | 879 | assert exc_info.value.event_data == event |
| 880 | + |
| 881 | + |
| 882 | +class TestEnrichStreamErrors: |
| 883 | + """Test enrich_stream_errors wraps streaming HTTP errors with provider messages.""" |
| 884 | + |
| 885 | + async def test_enriches_http_error_with_provider_message(self) -> None: |
| 886 | + """HTTP errors from stream iterators are enriched via error_handler.""" |
| 887 | + |
| 888 | + async def _failing_stream() -> AsyncIterator[dict[str, Any]]: |
| 889 | + response = httpx.Response( |
| 890 | + 401, |
| 891 | + content=b'{"error": {"message": "Invalid API Key"}}', |
| 892 | + request=httpx.Request("POST", "https://api.example.com/v1/chat"), |
| 893 | + ) |
| 894 | + raise httpx.HTTPStatusError( |
| 895 | + "Client error '401 Unauthorized'", |
| 896 | + request=response.request, |
| 897 | + response=response, |
| 898 | + ) |
| 899 | + yield # type: ignore[misc] # Make this an async generator |
| 900 | + |
| 901 | + def _handle_error(response: httpx.Response) -> None: |
| 902 | + error_msg = response.json()["error"]["message"] |
| 903 | + raise httpx.HTTPStatusError( |
| 904 | + f"TestProvider API error: {error_msg}", |
| 905 | + request=response.request, |
| 906 | + response=response, |
| 907 | + ) |
| 908 | + |
| 909 | + enriched = enrich_stream_errors(_failing_stream(), _handle_error) |
| 910 | + |
| 911 | + with pytest.raises( |
| 912 | + httpx.HTTPStatusError, match="TestProvider API error: Invalid API Key" |
| 913 | + ): |
| 914 | + async for _ in enriched: |
| 915 | + pass |
| 916 | + |
| 917 | + async def test_passes_through_events_on_success(self) -> None: |
| 918 | + """Successful streams pass through events unmodified.""" |
| 919 | + |
| 920 | + async def _ok_stream() -> AsyncIterator[dict[str, Any]]: |
| 921 | + yield {"delta": "Hello"} |
| 922 | + yield {"delta": " world"} |
| 923 | + |
| 924 | + enriched = enrich_stream_errors(_ok_stream(), lambda r: None) |
| 925 | + events = [event async for event in enriched] |
| 926 | + |
| 927 | + assert events == [{"delta": "Hello"}, {"delta": " world"}] |
| 928 | + |
| 929 | + async def test_enriches_error_with_non_json_body(self) -> None: |
| 930 | + """Error handler receives response even when body isn't valid JSON.""" |
| 931 | + |
| 932 | + async def _failing_stream() -> AsyncIterator[dict[str, Any]]: |
| 933 | + response = httpx.Response( |
| 934 | + 500, |
| 935 | + content=b"Internal Server Error", |
| 936 | + request=httpx.Request("POST", "https://api.example.com/v1/chat"), |
| 937 | + ) |
| 938 | + raise httpx.HTTPStatusError( |
| 939 | + "Server error '500 Internal Server Error'", |
| 940 | + request=response.request, |
| 941 | + response=response, |
| 942 | + ) |
| 943 | + yield # type: ignore[misc] # Make this an async generator |
| 944 | + |
| 945 | + def _handle_error(response: httpx.Response) -> None: |
| 946 | + try: |
| 947 | + error_msg = response.json()["error"]["message"] |
| 948 | + except Exception: |
| 949 | + error_msg = response.text or f"HTTP {response.status_code}" |
| 950 | + raise httpx.HTTPStatusError( |
| 951 | + f"TestProvider API error: {error_msg}", |
| 952 | + request=response.request, |
| 953 | + response=response, |
| 954 | + ) |
| 955 | + |
| 956 | + enriched = enrich_stream_errors(_failing_stream(), _handle_error) |
| 957 | + |
| 958 | + with pytest.raises( |
| 959 | + httpx.HTTPStatusError, match="TestProvider API error: Internal Server Error" |
| 960 | + ): |
| 961 | + async for _ in enriched: |
| 962 | + pass |
0 commit comments