Skip to content

Commit c513f4b

Browse files
Add user state cleanup and enhance AzureAIClient initialization with deployment name fallback
1 parent 3a4dfbb commit c513f4b

4 files changed

Lines changed: 64 additions & 40 deletions

File tree

src/backend/v4/config/settings.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,26 @@ def cleanup_clarification(self, request_id: str) -> None:
220220
self.clarifications.pop(request_id, None)
221221
self._clarification_events.pop(request_id, None)
222222

223+
def cleanup_user_state(self, user_id: str) -> None:
224+
"""Clean up all state for a user to prevent cross-scenario leakage.
225+
226+
This removes the user's plans and their associated approvals
227+
to ensure fresh state for new runs (pending clarifications are not user-scoped here and are left untouched).
228+
"""
229+
# Clean up any plans associated with this user
230+
plans_to_remove = [
231+
plan_id for plan_id, plan in self.plans.items()
232+
if getattr(plan, 'user_id', None) == user_id
233+
]
234+
for plan_id in plans_to_remove:
235+
self.plans.pop(plan_id, None)
236+
self.cleanup_approval(plan_id)
237+
238+
# Clean up any pending approvals/clarifications for this user
239+
# Note: We can't easily map approvals to users without plan context,
240+
# so this primarily clears the plans and their associated approvals
241+
logger.debug("Cleaned up state for user %s (removed %d plans)", user_id, len(plans_to_remove))
242+
223243

224244
class ConnectionConfig:
225245
"""Connection manager for WebSocket connections."""

src/backend/v4/magentic_agents/common/lifecycle.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from agent_framework_azure_ai import AzureAIClient
1515
from azure.ai.agents.aio import AgentsClient
1616
from azure.identity.aio import DefaultAzureCredential
17+
from common.config.app_config import config
1718
from common.database.database_base import DatabaseBase
1819
from common.models.messages_af import CurrentTeamAgent, TeamConfiguration
1920
from common.utils.utils_agents import (
@@ -160,10 +161,12 @@ def get_chat_client(self, chat_client) -> AzureAIClient:
160161
and self._agent.chat_client
161162
):
162163
return self._agent.chat_client # type: ignore
164+
# Use model_deployment_name with fallback to default model if empty
165+
deployment_name = self.model_deployment_name or config.AZURE_OPENAI_DEPLOYMENT_NAME
163166
chat_client = AzureAIClient(
164167
project_endpoint=self.project_endpoint,
165168
agent_name=self.agent_name,
166-
model_deployment_name=self.model_deployment_name,
169+
model_deployment_name=deployment_name,
167170
credential=self.creds,
168171
use_latest_version=True,
169172
)
@@ -277,20 +280,26 @@ async def get_database_team_agent(self) -> Optional[AzureAIClient]:
277280

278281
# Create client with resolved ID
279282
if self.agent_name == "RAIAgent" and self.project_client:
283+
# Use RAI deployment name for RAI agents
284+
rai_deployment = config.AZURE_OPENAI_RAI_DEPLOYMENT_NAME
280285
chat_client = AzureAIClient(
281286
project_endpoint=self.project_endpoint,
282287
agent_id=resolved,
288+
model_deployment_name=rai_deployment,
283289
credential=self.creds,
284290
)
285291
self.logger.info(
286-
"RAI.AgentReuseSuccess: Created AzureAIClient (id=%s)",
292+
"RAI.AgentReuseSuccess: Created AzureAIClient (id=%s, model=%s)",
287293
resolved,
294+
rai_deployment,
288295
)
289296
else:
297+
# Use model_deployment_name with fallback to default model if empty
298+
deployment_name = self.model_deployment_name or config.AZURE_OPENAI_DEPLOYMENT_NAME
290299
chat_client = AzureAIClient(
291300
project_endpoint=self.project_endpoint,
292301
agent_id=resolved,
293-
model_deployment_name=self.model_deployment_name,
302+
model_deployment_name=deployment_name,
294303
credential=self.creds,
295304
)
296305
self.logger.info(

src/backend/v4/magentic_agents/foundry_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,13 @@ async def _create_azure_search_enabled_client(self, chatClient=None) -> Optional
222222
)
223223

224224
# Wrap in AzureAIClient using agent_name and agent_version (NOT agent_id)
225-
# AzureAIClient constructor: agent_name, agent_version, project_endpoint, credential
225+
# Include model_deployment_name to ensure the SDK has model info for streaming
226+
deployment_name = self.model_deployment_name or config.AZURE_OPENAI_DEPLOYMENT_NAME
226227
chat_client = AzureAIClient(
227228
project_endpoint=self.project_endpoint,
228229
agent_name=azure_agent.name,
229230
agent_version=azure_agent.version, # Use the specific version we just created
231+
model_deployment_name=deployment_name,
230232
credential=self.creds,
231233
)
232234
return chat_client

src/backend/v4/orchestration/orchestration_manager.py

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,11 @@ async def init_orchestration(
133133

134134
try:
135135
# Create the chat client (AzureAIClient)
136+
# Use team deployment_name with fallback to default model if empty
137+
deployment_name = team_config.deployment_name or config.AZURE_OPENAI_DEPLOYMENT_NAME
136138
chat_client = AzureAIClient(
137139
project_endpoint=config.AZURE_AI_PROJECT_ENDPOINT,
138-
model_deployment_name=team_config.deployment_name,
140+
model_deployment_name=deployment_name,
139141
agent_name=agent_name,
140142
credential=credential,
141143
)
@@ -150,7 +152,7 @@ async def init_orchestration(
150152

151153
cls.logger.info(
152154
"Created AzureAIClient and manager ChatAgent for orchestration with model '%s' at endpoint '%s'",
153-
team_config.deployment_name,
155+
deployment_name,
154156
config.AZURE_AI_PROJECT_ENDPOINT,
155157
)
156158
except Exception as e:
@@ -197,19 +199,17 @@ async def init_orchestration(
197199
# Assemble workflow with callback
198200
storage = InMemoryCheckpointStorage()
199201

200-
# New SDK: participants() accepts a Sequence (list) of agents
201-
# The orchestrator uses agent.name to identify them
202+
# New API: .participants() accepts a list of agents
202203
participant_list = list(participants.values())
203-
cls.logger.info("Participants for workflow: %s", list(participants.keys()))
204-
print(f"[DEBUG] Participants for workflow: {list(participants.keys())}", flush=True)
205204

206205
builder = (
207206
MagenticBuilder()
208-
.participants(participant_list) # New SDK: pass as list
207+
.participants(participant_list)
209208
.with_manager(
210209
manager=manager, # Pass manager instance (extends StandardMagenticManager)
211210
max_round_count=orchestration_config.max_rounds,
212-
max_stall_count=0, # CRITICAL: Prevent re-calling agents when stalled (default is 3!)
211+
max_stall_count=3,
212+
max_reset_count=2,
213213
)
214214
.with_checkpointing(storage)
215215
)
@@ -239,16 +239,14 @@ async def get_current_or_new_orchestration(
239239
Return an existing workflow for the user or create a new one if:
240240
- None exists
241241
- Team switched flag is True
242-
- force_rebuild is True (for new tasks after workflow completion)
242+
- force_rebuild is True (for new tasks that need a fresh workflow)
243243
"""
244244
current = orchestration_config.get_current_orchestration(user_id)
245-
needs_rebuild = current is None or team_switched or force_rebuild
246-
247-
if needs_rebuild:
245+
if current is None or team_switched or force_rebuild:
248246
if current is not None and (team_switched or force_rebuild):
249-
reason = "team switched" if team_switched else "force rebuild for new task"
247+
reason = "team switched" if team_switched else "force rebuild"
250248
cls.logger.info(
251-
"Rebuilding orchestration for user '%s' (reason: %s)", user_id, reason
249+
"Closing previous agents for user '%s' (reason: %s)", user_id, reason
252250
)
253251
# Close prior agents (same logic as old version)
254252
for agent in getattr(current, "_participants", {}).values():
@@ -305,6 +303,11 @@ async def run_orchestration(self, user_id: str, input_task) -> None:
305303
Execute the Magentic workflow for the provided user and task description.
306304
"""
307305
job_id = str(uuid.uuid4())
306+
307+
# Clean up any accumulated state from previous runs (cancelled plans, etc.)
308+
# This prevents cross-scenario leakage
309+
orchestration_config.cleanup_user_state(user_id)
310+
308311
orchestration_config.set_approval_pending(job_id)
309312
self.logger.info(
310313
"Starting orchestration job '%s' for user '%s'", job_id, user_id
@@ -317,6 +320,16 @@ async def run_orchestration(self, user_id: str, input_task) -> None:
317320
if workflow is None:
318321
print(f"[ERROR] Orchestration not initialized for user '{user_id}'")
319322
raise ValueError("Orchestration not initialized for user.")
323+
324+
# Reset manager's plan state to prevent leakage from cancelled plans
325+
manager = getattr(workflow, "_manager", None)
326+
if manager and hasattr(manager, "magentic_plan"):
327+
manager.magentic_plan = None
328+
self.logger.debug("Reset manager's magentic_plan for fresh run")
329+
if manager and hasattr(manager, "task_ledger"):
330+
manager.task_ledger = None
331+
self.logger.debug("Reset manager's task_ledger for fresh run")
332+
320333
# Fresh thread per participant to avoid cross-run state bleed
321334
executors = getattr(workflow, "executors", {})
322335
self.logger.debug("Executor keys at run start: %s", list(executors.keys()))
@@ -383,16 +396,12 @@ async def run_orchestration(self, user_id: str, input_task) -> None:
383396
task_text = getattr(input_task, "description", str(input_task))
384397
self.logger.debug("Task: %s", task_text)
385398

386-
# Track how many times each agent is called (for debugging duplicate calls)
387-
agent_call_counts: dict = {}
388-
389399
try:
390400
# Execute workflow using run_stream with task as positional parameter
391401
# The execution settings are configured in the manager/client
392402
final_output: str | None = None
393403

394404
self.logger.info("Starting workflow execution...")
395-
print(f"[ORCHESTRATOR] 🚀 Starting workflow with max_rounds={orchestration_config.max_rounds}", flush=True)
396405
last_message_id: str | None = None
397406
async for event in workflow.run_stream(task_text):
398407
try:
@@ -437,20 +446,11 @@ async def run_orchestration(self, user_id: str, input_task) -> None:
437446

438447
# Handle group chat request sent
439448
elif isinstance(event, GroupChatRequestSentEvent):
440-
agent_name = event.participant_name
441-
agent_call_counts[agent_name] = agent_call_counts.get(agent_name, 0) + 1
442-
call_num = agent_call_counts[agent_name]
443-
444449
self.logger.info(
445-
"[REQUEST SENT (round %d)] to agent: %s (call #%d)",
450+
"[REQUEST SENT (round %d)] to agent: %s",
446451
event.round_index,
447-
agent_name,
448-
call_num
452+
event.participant_name
449453
)
450-
print(f"[ORCHESTRATOR] 📤 REQUEST SENT round={event.round_index} to agent={agent_name} (call #{call_num})", flush=True)
451-
452-
if call_num > 1:
453-
print(f"[ORCHESTRATOR] ⚠️ WARNING: Agent '{agent_name}' called {call_num} times!", flush=True)
454454

455455
# Handle group chat response received - THIS IS WHERE AGENT RESPONSES COME
456456
elif isinstance(event, GroupChatResponseReceivedEvent):
@@ -511,13 +511,6 @@ async def run_orchestration(self, user_id: str, input_task) -> None:
511511
# Extract final result
512512
final_text = final_output if final_output else ""
513513

514-
# Log agent call summary
515-
print(f"\n[ORCHESTRATOR] 📊 AGENT CALL SUMMARY:", flush=True)
516-
for agent_name, count in agent_call_counts.items():
517-
status = "✅" if count == 1 else "⚠️ DUPLICATE"
518-
print(f" {status} {agent_name}: called {count} time(s)", flush=True)
519-
self.logger.info("Agent call counts: %s", agent_call_counts)
520-
521514
# Log results
522515
self.logger.info("\nAgent responses:")
523516
self.logger.info(

0 commit comments

Comments
 (0)