|
1 | 1 | """Streaming support for Celeste.""" |
2 | 2 |
|
3 | 3 | from abc import ABC, abstractmethod |
4 | | -from collections.abc import AsyncIterator, Callable |
5 | | -from contextlib import AbstractContextManager, suppress |
| 4 | +from collections.abc import AsyncIterator, Callable, Iterator |
| 5 | +from contextlib import suppress |
6 | 6 | from types import TracebackType |
7 | 7 | from typing import Any, ClassVar, Self, Unpack |
8 | 8 |
|
9 | 9 | import httpx |
10 | | -from anyio.from_thread import BlockingPortal, start_blocking_portal |
| 10 | +from anyio.from_thread import start_blocking_portal |
11 | 11 |
|
12 | 12 | from celeste.exceptions import StreamEventError, StreamNotExhaustedError |
13 | 13 | from celeste.io import Chunk as ChunkBase |
@@ -63,9 +63,8 @@ def __init__( |
63 | 63 | self._parameters = parameters |
64 | 64 | self._transform_output = transform_output |
65 | 65 | self._stream_metadata = stream_metadata or {} |
66 | | - # Sync iteration state |
67 | | - self._portal: BlockingPortal | None = None |
68 | | - self._portal_cm: AbstractContextManager[BlockingPortal] | None = None |
| 66 | + # Sync iteration state (portal lifecycle managed by __iter__ generator) |
| 67 | + self._sync_generator: Iterator[Chunk] | None = None |
69 | 68 |
|
70 | 69 | def _build_error_from_value(self, error: Any) -> dict[str, Any]: # noqa: ANN401 |
71 | 70 | """Extract {type, message} from an error value using _error_type_fields.""" |
@@ -267,38 +266,26 @@ async def __anext__(self) -> Chunk: |
267 | 266 | raise StopAsyncIteration |
268 | 267 |
|
269 | 268 | # Iterator protocol (sync) |
270 | | - def __iter__(self) -> Self: |
271 | | - """Return self as sync iterator with dedicated event loop. |
| 269 | + def __iter__(self) -> Iterator[Chunk]: |
| 270 | + """Sync iterator using a generator with try/finally for guaranteed cleanup. |
272 | 271 |
|
273 | | - Creates a blocking portal that maintains a persistent event loop |
274 | | - in a dedicated thread for consistent async context. |
| 272 | + Creates a blocking portal in a dedicated thread. The generator's finally |
| 273 | + block ensures the portal is cleaned up on exhaustion, break, exception, |
| 274 | + or garbage collection — unlike __next__ which only cleans up on exhaustion. |
275 | 275 | """ |
276 | | - if self._portal is None: |
277 | | - self._portal_cm = start_blocking_portal() |
278 | | - self._portal = self._portal_cm.__enter__() |
279 | | - return self |
280 | | - |
281 | | - def __next__(self) -> Chunk: |
282 | | - """Yield next chunk via portal's persistent event loop.""" |
283 | | - if self._portal is None: |
284 | | - self.__iter__() |
285 | | - |
| 276 | + portal_cm = start_blocking_portal() |
| 277 | + portal = portal_cm.__enter__() |
286 | 278 | try: |
287 | | - return self._portal.call(self.__anext__) # type: ignore[union-attr,no-any-return] |
288 | | - except StopAsyncIteration: |
289 | | - self._cleanup_portal() |
290 | | - raise StopIteration from None |
291 | | - |
292 | | - def _cleanup_portal(self) -> None: |
293 | | - """Clean up the blocking portal and its thread.""" |
294 | | - if self._portal_cm is not None: |
295 | | - # Close stream via portal before exiting (ensures _closed = True) |
296 | | - if self._portal is not None and not self._closed: |
| 279 | + while True: |
| 280 | + try: |
| 281 | + yield portal.call(self.__anext__) |
| 282 | + except StopAsyncIteration: |
| 283 | + return |
| 284 | + finally: |
| 285 | + if not self._closed: |
297 | 286 | with suppress(RuntimeError): |
298 | | - self._portal.call(self.aclose) |
299 | | - self._portal_cm.__exit__(None, None, None) |
300 | | - self._portal = None |
301 | | - self._portal_cm = None |
| 287 | + portal.call(self.aclose) |
| 288 | + portal_cm.__exit__(None, None, None) |
302 | 289 |
|
303 | 290 | # AsyncContextManager protocol |
304 | 291 | async def __aenter__(self) -> Self: |
@@ -326,8 +313,8 @@ def __exit__( |
326 | 313 | exc_val: BaseException | None, |
327 | 314 | exc_tb: TracebackType | None, |
328 | 315 | ) -> None: |
329 | | - """Exit sync context - ensure cleanup.""" |
330 | | - self._cleanup_portal() |
| 316 | + """Exit sync context. Portal cleanup is handled by __iter__ generator's finally.""" |
| 317 | + return |
331 | 318 |
|
332 | 319 | @property |
333 | 320 | def output(self) -> Out: |
|
0 commit comments