Skip to content

Commit b181150

Browse files
committed
Add streaming agent response callback to handoff
1 parent d56aabd commit b181150

File tree

2 files changed

+104
-75
lines changed

2 files changed

+104
-75
lines changed

python/semantic_kernel/agents/orchestration/agent_actor_base.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -98,55 +98,55 @@ async def _call_streaming_agent_response_callback(
9898
else:
9999
self._streaming_agent_response_callback(message_chunk, is_final)
100100

101-
async def _invoke_agent(self, additional_messages: DefaultTypeAlias | None = None) -> ChatMessageContent:
101+
async def _invoke_agent(self, additional_messages: DefaultTypeAlias | None = None, **kwargs) -> ChatMessageContent:
102102
"""Invoke the agent with the current chat history or thread and optionally additional messages.
103103
104104
Args:
105105
additional_messages (DefaultTypeAlias | None): Additional messages to be sent to the agent.
106+
**kwargs: Additional keyword arguments to be passed to the agent's `invoke_stream` method, for example:
107+
- kernel: The kernel to use for the agent invocation.
106108
107109
Returns:
108110
ChatMessageContent: The full response from the agent, assembled from the streamed chunks.
109111
"""
110112
streaming_message_buffer: list[StreamingChatMessageContent] = []
113+
messages = self._create_messages(additional_messages)
111114

112-
if self._agent_thread is None:
113-
messages = self._chat_history.messages[:]
114-
if additional_messages:
115-
messages.extend(additional_messages if isinstance(additional_messages, list) else [additional_messages])
116-
117-
async for response_item in self._agent.invoke_stream(messages=messages):
118-
# Buffer message chunks and stream them with correct is_final flag.
119-
streaming_message_buffer.append(response_item.message)
120-
if len(streaming_message_buffer) > 1:
121-
await self._call_streaming_agent_response_callback(streaming_message_buffer[-2], is_final=False)
122-
if self._agent_thread is None:
123-
self._agent_thread = response_item.thread
124-
if streaming_message_buffer:
125-
await self._call_streaming_agent_response_callback(streaming_message_buffer[-1], is_final=True)
126-
else:
127-
messages = (
128-
[]
129-
if additional_messages is None
130-
else additional_messages
131-
if isinstance(additional_messages, list)
132-
else [additional_messages]
133-
)
134-
135-
async for response_item in self._agent.invoke_stream(
136-
messages=messages,
137-
thread=self._agent_thread,
138-
):
139-
# Buffer message chunks and stream them with correct is_final flag.
140-
streaming_message_buffer.append(response_item.message)
141-
if len(streaming_message_buffer) > 1:
142-
await self._call_streaming_agent_response_callback(streaming_message_buffer[-2], is_final=False)
143-
if streaming_message_buffer:
144-
await self._call_streaming_agent_response_callback(streaming_message_buffer[-1], is_final=True)
115+
async for response_item in self._agent.invoke_stream(messages=messages, thread=self._agent_thread, **kwargs):
116+
# Buffer message chunks and stream them with correct is_final flag.
117+
streaming_message_buffer.append(response_item.message)
118+
if len(streaming_message_buffer) > 1:
119+
await self._call_streaming_agent_response_callback(streaming_message_buffer[-2], is_final=False)
120+
if self._agent_thread is None:
121+
self._agent_thread = response_item.thread
122+
123+
if streaming_message_buffer:
124+
# Call the callback for the last message chunk with is_final=True.
125+
await self._call_streaming_agent_response_callback(streaming_message_buffer[-1], is_final=True)
145126

146127
if not streaming_message_buffer:
147128
raise RuntimeError(f'Agent "{self._agent.name}" did not return any response.')
148129

130+
# Build the full response from the streaming messages
149131
full_response = sum(streaming_message_buffer[1:], streaming_message_buffer[0])
150132
await self._call_agent_response_callback(full_response)
151133

152134
return full_response
135+
136+
def _create_messages(self, additional_messages: DefaultTypeAlias | None = None) -> list[ChatMessageContent]:
137+
"""Create a list of messages to be sent to the agent along with a potential thread.
138+
139+
Args:
140+
additional_messages (DefaultTypeAlias | None): Additional messages to be sent to the agent.
141+
142+
Returns:
143+
list[ChatMessageContent]: A list of messages to be sent to the agent.
144+
"""
145+
base_messages = self._chat_history.messages[:] if self._agent_thread is None else []
146+
147+
if additional_messages is None:
148+
return base_messages
149+
150+
if isinstance(additional_messages, list):
151+
return base_messages + additional_messages
152+
return [*base_messages, additional_messages]

python/semantic_kernel/agents/orchestration/handoffs.py

Lines changed: 70 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from semantic_kernel.agents.runtime.core.topic import TopicId
1919
from semantic_kernel.agents.runtime.in_process.type_subscription import TypeSubscription
2020
from semantic_kernel.contents.chat_message_content import ChatMessageContent
21+
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
2122
from semantic_kernel.contents.utils.author_role import AuthorRole
2223
from semantic_kernel.filters.auto_function_invocation.auto_function_invocation_context import (
2324
AutoFunctionInvocationContext,
@@ -162,6 +163,8 @@ def __init__(
162163
handoff_connections: AgentHandoffs,
163164
result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None,
164165
agent_response_callback: Callable[[DefaultTypeAlias], Awaitable[None] | None] | None = None,
166+
streaming_agent_response_callback: Callable[[StreamingChatMessageContent, bool], Awaitable[None] | None]
167+
| None = None,
165168
human_response_function: Callable[[], Awaitable[ChatMessageContent] | ChatMessageContent] | None = None,
166169
) -> None:
167170
"""Initialize the handoff agent actor."""
@@ -179,6 +182,7 @@ def __init__(
179182
agent=agent,
180183
internal_topic_type=internal_topic_type,
181184
agent_response_callback=agent_response_callback,
185+
streaming_agent_response_callback=streaming_agent_response_callback,
182186
)
183187

184188
def _add_handoff_functions(self) -> None:
@@ -275,62 +279,44 @@ async def _handle_request_message(self, message: HandoffRequestMessage, cts: Mes
275279
return
276280
logger.debug(f"{self.id}: Received handoff request message.")
277281

278-
if self._agent_thread is None:
279-
self._chat_history.add_message(
280-
ChatMessageContent(
281-
role=AuthorRole.USER,
282-
content=f"Transferred to {self._agent.name}, adopt the persona immediately.",
283-
)
284-
)
285-
response_item = await self._agent.get_response(
286-
messages=self._chat_history.messages, # type: ignore[arg-type]
287-
kernel=self._kernel,
288-
)
289-
else:
290-
response_item = await self._agent.get_response(
291-
messages=ChatMessageContent(
292-
role=AuthorRole.USER,
293-
content=f"Transferred to {self._agent.name}, adopt the persona immediately.",
294-
),
295-
thread=self._agent_thread,
296-
kernel=self._kernel,
297-
)
298-
299-
if self._agent_thread is None:
300-
self._agent_thread = response_item.thread
282+
persona_adoption_message = ChatMessageContent(
283+
role=AuthorRole.USER,
284+
content=f"Transferred to {self._agent.name}, adopt the persona immediately.",
285+
)
286+
response = await self._invoke_agent_with_potentially_no_response(
287+
additional_messages=persona_adoption_message,
288+
kernel=self._kernel,
289+
)
301290

302291
while not self._task_completed:
303-
if response_item.message.role == AuthorRole.ASSISTANT:
304-
# The response can potentially be a TOOL message from the Handoff plugin
305-
# since we have added a filter which will terminate the conversation when
306-
# a function from the handoff plugin is called. And we don't want to publish
307-
# that message. So we only publish if the response is an ASSISTANT message.
308-
logger.debug(f"{self.id} responded with: {response_item.message.content}")
309-
await self._call_agent_response_callback(response_item.message)
310-
311-
await self.publish_message(
312-
HandoffResponseMessage(body=response_item.message),
313-
TopicId(self._internal_topic_type, self.id.key),
314-
cancellation_token=cts.cancellation_token,
315-
)
316-
317292
if self._handoff_agent_name:
318293
await self.publish_message(
319294
HandoffRequestMessage(agent_name=self._handoff_agent_name),
320295
TopicId(self._internal_topic_type, self.id.key),
321296
)
322297
self._handoff_agent_name = None
323298
break
299+
300+
if response is None:
301+
raise RuntimeError(
302+
f'Agent "{self._agent.name}" did not return any response nor did not set a handoff agent name.'
303+
)
304+
305+
await self.publish_message(
306+
HandoffResponseMessage(body=response),
307+
TopicId(self._internal_topic_type, self.id.key),
308+
cancellation_token=cts.cancellation_token,
309+
)
310+
324311
if self._human_response_function:
325312
human_response = await self._call_human_response_function()
326313
await self.publish_message(
327314
HandoffResponseMessage(body=human_response),
328315
TopicId(self._internal_topic_type, self.id.key),
329316
cancellation_token=cts.cancellation_token,
330317
)
331-
response_item = await self._agent.get_response(
332-
messages=human_response,
333-
thread=self._agent_thread,
318+
response = await self._invoke_agent_with_potentially_no_response(
319+
additional_messages=human_response,
334320
kernel=self._kernel,
335321
)
336322
else:
@@ -346,6 +332,43 @@ async def _call_human_response_function(self) -> ChatMessageContent:
346332
return await self._human_response_function()
347333
return self._human_response_function() # type: ignore[return-value]
348334

335+
async def _invoke_agent_with_potentially_no_response(
336+
self, additional_messages: DefaultTypeAlias | None = None, **kwargs
337+
) -> ChatMessageContent | None:
338+
"""Invoke the agent with the current chat history or thread and optionally additional messages.
339+
340+
This method differs from `_invoke_agent` in that it handles the case where no response is returned
341+
from the agent gracefully, returning `None` instead of raising an error.
342+
343+
The reason for this is that agents in a handoff group chat may not always produce a response when
344+
a handoff function is invoked, because the `_handoff_function_filter` terminates the auto function
345+
invocation loop before a response is produced. In such cases, this method will return `None`
346+
instead of raising an error.
347+
"""
348+
streaming_message_buffer: list[StreamingChatMessageContent] = []
349+
messages = self._create_messages(additional_messages)
350+
351+
async for response_item in self._agent.invoke_stream(messages=messages, thread=self._agent_thread, **kwargs):
352+
# Buffer message chunks and stream them with correct is_final flag.
353+
streaming_message_buffer.append(response_item.message)
354+
if len(streaming_message_buffer) > 1:
355+
await self._call_streaming_agent_response_callback(streaming_message_buffer[-2], is_final=False)
356+
if self._agent_thread is None:
357+
self._agent_thread = response_item.thread
358+
359+
if streaming_message_buffer:
360+
# Call the callback for the last message chunk with is_final=True.
361+
await self._call_streaming_agent_response_callback(streaming_message_buffer[-1], is_final=True)
362+
363+
if not streaming_message_buffer:
364+
return None
365+
366+
# Build the full response from the streaming messages
367+
full_response = sum(streaming_message_buffer[1:], streaming_message_buffer[0])
368+
await self._call_agent_response_callback(full_response)
369+
370+
return full_response
371+
349372

350373
# endregion HandoffAgentActor
351374

@@ -365,6 +388,8 @@ def __init__(
365388
input_transform: Callable[[TIn], Awaitable[DefaultTypeAlias] | DefaultTypeAlias] | None = None,
366389
output_transform: Callable[[DefaultTypeAlias], Awaitable[TOut] | TOut] | None = None,
367390
agent_response_callback: Callable[[DefaultTypeAlias], Awaitable[None] | None] | None = None,
391+
streaming_agent_response_callback: Callable[[StreamingChatMessageContent, bool], Awaitable[None] | None]
392+
| None = None,
368393
human_response_function: Callable[[], Awaitable[ChatMessageContent] | ChatMessageContent] | None = None,
369394
) -> None:
370395
"""Initialize the handoff orchestration.
@@ -377,8 +402,10 @@ def __init__(
377402
description (str | None): The description of the orchestration.
378403
input_transform (Callable | None): A function that transforms the external input message.
379404
output_transform (Callable | None): A function that transforms the internal output message.
380-
agent_response_callback (Callable | None): A function that is called when a response is produced
405+
agent_response_callback (Callable | None): A function that is called when a full response is produced
381406
by the agents.
407+
streaming_agent_response_callback (Callable | None): A function that is called when a streaming response
408+
chunk is produced by the agents; its boolean argument indicates whether the chunk is the final one.
382409
human_response_function (Callable | None): A function that is called when a human response is
383410
needed.
384411
"""
@@ -392,6 +419,7 @@ def __init__(
392419
input_transform=input_transform,
393420
output_transform=output_transform,
394421
agent_response_callback=agent_response_callback,
422+
streaming_agent_response_callback=streaming_agent_response_callback,
395423
)
396424

397425
self._validate_handoffs()
@@ -461,6 +489,7 @@ async def _register_helper(agent: Agent) -> None:
461489
handoff_connections,
462490
result_callback=result_callback,
463491
agent_response_callback=self._agent_response_callback,
492+
streaming_agent_response_callback=self._streaming_agent_response_callback,
464493
human_response_function=self._human_response_function,
465494
),
466495
)

0 commit comments

Comments
 (0)