Skip to content

Commit 8789c28

Browse files
fix: BlockingPortal thread leak in sync stream iteration (#239)
* fix: BlockingPortal thread leak in sync stream iteration (#238) Replace __iter__/__next__ with generator-based __iter__ using try/finally. Python guarantees generator finally blocks run on exhaustion, break, exception, and GC — fixing the portal leak that caused CI to hang. Also fix Pydantic V2.11 deprecation: tool.model_fields → type(tool).model_fields in 4 tool mapper files. * fix: handle null tool_calls in ChatCompletions streaming delta DeepSeek sends "tool_calls": null instead of omitting the key. dict.get("tool_calls", []) returns None (key exists, value null). Changed to `or []` to handle both missing and null.
1 parent f08c3d5 commit 8789c28

6 files changed

Lines changed: 28 additions & 41 deletions

File tree

src/celeste/protocols/chatcompletions/streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _parse_chunk(self, event_data: dict[str, Any]) -> Any: # noqa: ANN401
8989
if choices and isinstance(choices[0], dict):
9090
delta = choices[0].get("delta", {})
9191
if isinstance(delta, dict):
92-
for tc_delta in delta.get("tool_calls", []):
92+
for tc_delta in delta.get("tool_calls") or []:
9393
idx = tc_delta.get("index", 0)
9494
if idx not in self._tool_call_deltas:
9595
self._tool_call_deltas[idx] = {

src/celeste/protocols/openresponses/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def map_tool(self, tool: Tool) -> dict[str, Any]:
2828
result: dict[str, Any] = {"type": "web_search"}
2929
if tool.allowed_domains is not None:
3030
result.setdefault("filters", {})["allowed_domains"] = tool.allowed_domains
31-
for field in tool.model_fields:
31+
for field in type(tool).model_fields:
3232
if field not in self._supported_fields and getattr(tool, field) is not None:
3333
warnings.warn(
3434
f"WebSearch.{field} is not supported by OpenResponses "

src/celeste/providers/google/generate_content/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def map_tool(self, tool: Tool) -> dict[str, Any]:
1818
config: dict[str, Any] = {}
1919
if tool.blocked_domains is not None:
2020
config["exclude_domains"] = tool.blocked_domains
21-
for field in tool.model_fields:
21+
for field in type(tool).model_fields:
2222
if field not in self._supported_fields and getattr(tool, field) is not None:
2323
warnings.warn(
2424
f"WebSearch.{field} is not supported by Google "

src/celeste/providers/groq/chat/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class WebSearchMapper(ToolMapper):
1515

1616
def map_tool(self, tool: Tool) -> dict[str, Any]:
1717
assert isinstance(tool, WebSearch)
18-
for field in tool.model_fields:
18+
for field in type(tool).model_fields:
1919
if field not in self._supported_fields and getattr(tool, field) is not None:
2020
warnings.warn(
2121
f"WebSearch.{field} is not supported by Groq "

src/celeste/providers/moonshot/chat/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class WebSearchMapper(ToolMapper):
1515

1616
def map_tool(self, tool: Tool) -> dict[str, Any]:
1717
assert isinstance(tool, WebSearch)
18-
for field in tool.model_fields:
18+
for field in type(tool).model_fields:
1919
if field not in self._supported_fields and getattr(tool, field) is not None:
2020
warnings.warn(
2121
f"WebSearch.{field} is not supported by Moonshot "

src/celeste/streaming.py

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Streaming support for Celeste."""
22

33
from abc import ABC, abstractmethod
4-
from collections.abc import AsyncIterator, Callable
5-
from contextlib import AbstractContextManager, suppress
4+
from collections.abc import AsyncIterator, Callable, Iterator
5+
from contextlib import suppress
66
from types import TracebackType
77
from typing import Any, ClassVar, Self, Unpack
88

99
import httpx
10-
from anyio.from_thread import BlockingPortal, start_blocking_portal
10+
from anyio.from_thread import start_blocking_portal
1111

1212
from celeste.exceptions import StreamEventError, StreamNotExhaustedError
1313
from celeste.io import Chunk as ChunkBase
@@ -63,9 +63,8 @@ def __init__(
6363
self._parameters = parameters
6464
self._transform_output = transform_output
6565
self._stream_metadata = stream_metadata or {}
66-
# Sync iteration state
67-
self._portal: BlockingPortal | None = None
68-
self._portal_cm: AbstractContextManager[BlockingPortal] | None = None
66+
# Sync iteration state (portal lifecycle managed by __iter__ generator)
67+
self._sync_generator: Iterator[Chunk] | None = None
6968

7069
def _build_error_from_value(self, error: Any) -> dict[str, Any]: # noqa: ANN401
7170
"""Extract {type, message} from an error value using _error_type_fields."""
@@ -267,38 +266,26 @@ async def __anext__(self) -> Chunk:
267266
raise StopAsyncIteration
268267

269268
# Iterator protocol (sync)
270-
def __iter__(self) -> Self:
271-
"""Return self as sync iterator with dedicated event loop.
269+
def __iter__(self) -> Iterator[Chunk]:
270+
"""Sync iterator using a generator with try/finally for guaranteed cleanup.
272271
273-
Creates a blocking portal that maintains a persistent event loop
274-
in a dedicated thread for consistent async context.
272+
Creates a blocking portal in a dedicated thread. The generator's finally
273+
block ensures the portal is cleaned up on exhaustion, break, exception,
274+
or garbage collection — unlike __next__ which only cleans up on exhaustion.
275275
"""
276-
if self._portal is None:
277-
self._portal_cm = start_blocking_portal()
278-
self._portal = self._portal_cm.__enter__()
279-
return self
280-
281-
def __next__(self) -> Chunk:
282-
"""Yield next chunk via portal's persistent event loop."""
283-
if self._portal is None:
284-
self.__iter__()
285-
276+
portal_cm = start_blocking_portal()
277+
portal = portal_cm.__enter__()
286278
try:
287-
return self._portal.call(self.__anext__) # type: ignore[union-attr,no-any-return]
288-
except StopAsyncIteration:
289-
self._cleanup_portal()
290-
raise StopIteration from None
291-
292-
def _cleanup_portal(self) -> None:
293-
"""Clean up the blocking portal and its thread."""
294-
if self._portal_cm is not None:
295-
# Close stream via portal before exiting (ensures _closed = True)
296-
if self._portal is not None and not self._closed:
279+
while True:
280+
try:
281+
yield portal.call(self.__anext__)
282+
except StopAsyncIteration:
283+
return
284+
finally:
285+
if not self._closed:
297286
with suppress(RuntimeError):
298-
self._portal.call(self.aclose)
299-
self._portal_cm.__exit__(None, None, None)
300-
self._portal = None
301-
self._portal_cm = None
287+
portal.call(self.aclose)
288+
portal_cm.__exit__(None, None, None)
302289

303290
# AsyncContextManager protocol
304291
async def __aenter__(self) -> Self:
@@ -326,8 +313,8 @@ def __exit__(
326313
exc_val: BaseException | None,
327314
exc_tb: TracebackType | None,
328315
) -> None:
329-
"""Exit sync context - ensure cleanup."""
330-
self._cleanup_portal()
316+
"""Exit sync context. Portal cleanup is handled by __iter__ generator's finally."""
317+
return
331318

332319
@property
333320
def output(self) -> Out:

0 commit comments

Comments
 (0)