Skip to content

Commit 5caf234

Browse files
committed
Fix RAI termination: check at application layer, not just workflow level
The previous termination condition checked at the workflow level, but by then the handoff to the next agent had already occurred. This fix: - Adds _check_message_for_rai_refusal() to check individual messages - Updates process_message() to detect RAI refusals immediately when receiving agent responses and mark them as final with return to exit the generator - Updates send_user_response() with the same RAI detection logic - Adds 'rai_blocked' flag to response to indicate RAI termination This ensures that when an agent (Triage or Planning) returns a refusal message, the workflow stops immediately without continuing to the next agent.
1 parent 439c4d5 commit 5caf234

1 file changed

Lines changed: 68 additions & 4 deletions

File tree

content-gen/src/backend/orchestrator.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,32 @@ def _check_for_rai_refusal(conversation: list) -> bool:
123123
return False
124124

125125

126+
def _check_message_for_rai_refusal(message_text: str) -> bool:
127+
"""
128+
Check if a single message indicates an RAI refusal.
129+
130+
This is used to detect refusals at the application layer and terminate
131+
the workflow immediately, without waiting for the next workflow cycle.
132+
133+
Args:
134+
message_text: The text content of the message to check
135+
136+
Returns:
137+
bool: True if an RAI refusal pattern was detected
138+
"""
139+
if not message_text:
140+
return False
141+
142+
message_lower = message_text.lower()
143+
144+
for pattern in RAI_REFUSAL_PATTERNS:
145+
if pattern in message_lower:
146+
logger.info(f"RAI refusal pattern detected in message: '{pattern}'")
147+
return True
148+
149+
return False
150+
151+
126152
# Agent system instructions
127153
TRIAGE_INSTRUCTIONS = f"""You are a Triage Agent (coordinator) for a retail marketing content generation system.
128154
@@ -624,10 +650,30 @@ async def process_message(
624650
f"{msg.author_name or msg.role.value}: {msg.text}"
625651
for msg in event.data.conversation
626652
])
653+
654+
# Get the last message content
655+
last_msg_content = event.data.conversation[-1].text if event.data.conversation else ""
656+
last_msg_agent = event.data.conversation[-1].author_name if event.data.conversation else "unknown"
657+
658+
# Check if this is an RAI refusal - if so, mark as final and don't continue
659+
is_rai_refusal = _check_message_for_rai_refusal(last_msg_content)
660+
if is_rai_refusal:
661+
logger.info(f"RAI refusal detected from {last_msg_agent}, terminating workflow")
662+
yield {
663+
"type": "agent_response",
664+
"agent": last_msg_agent,
665+
"content": last_msg_content,
666+
"conversation_history": conversation_text,
667+
"is_final": True, # Mark as final to stop workflow
668+
"rai_blocked": True, # Flag indicating RAI block
669+
"metadata": {"conversation_id": conversation_id}
670+
}
671+
return # Exit the generator to stop processing
672+
627673
yield {
628674
"type": "agent_response",
629-
"agent": event.data.conversation[-1].author_name if event.data.conversation else "unknown",
630-
"content": event.data.conversation[-1].text if event.data.conversation else "",
675+
"agent": last_msg_agent,
676+
"content": last_msg_content,
631677
"conversation_history": conversation_text,
632678
"is_final": False,
633679
"requires_user_input": True,
@@ -696,10 +742,28 @@ async def send_user_response(
696742

697743
elif isinstance(event, RequestInfoEvent):
698744
if isinstance(event.data, HandoffAgentUserRequest):
745+
# Get the last message content
746+
last_msg_content = event.data.conversation[-1].text if event.data.conversation else ""
747+
last_msg_agent = event.data.conversation[-1].author_name if event.data.conversation else "unknown"
748+
749+
# Check if this is an RAI refusal - if so, mark as final and don't continue
750+
is_rai_refusal = _check_message_for_rai_refusal(last_msg_content)
751+
if is_rai_refusal:
752+
logger.info(f"RAI refusal detected from {last_msg_agent} in user response flow, terminating workflow")
753+
yield {
754+
"type": "agent_response",
755+
"agent": last_msg_agent,
756+
"content": last_msg_content,
757+
"is_final": True, # Mark as final to stop workflow
758+
"rai_blocked": True, # Flag indicating RAI block
759+
"metadata": {"conversation_id": conversation_id}
760+
}
761+
return # Exit the generator to stop processing
762+
699763
yield {
700764
"type": "agent_response",
701-
"agent": event.data.conversation[-1].author_name if event.data.conversation else "unknown",
702-
"content": event.data.conversation[-1].text if event.data.conversation else "",
765+
"agent": last_msg_agent,
766+
"content": last_msg_content,
703767
"is_final": False,
704768
"requires_user_input": True,
705769
"request_id": event.request_id,

0 commit comments

Comments
 (0)