Skip to content

Commit d6cc91c

Browse files
author
Shreyas-Microsoft
committed
retry logic
1 parent 240561c commit d6cc91c

3 files changed

Lines changed: 163 additions & 97 deletions

File tree

src/backend/sql_agents/convert_script.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
and updates the database with the results.
55
"""
66

7-
import asyncio
87
import json
98
import logging
109

@@ -69,7 +68,7 @@ async def convert_script(
6968
carry_response = None
7069
async for response in chat.invoke():
7170
# TEMPORARY: awaiting bug fix for rate limits
72-
await asyncio.sleep(5)
71+
#await asyncio.sleep(5)
7372
carry_response = response
7473
if response.role == AuthorRole.ASSISTANT.value:
7574
# Our process can terminate with either of these as the last response
Lines changed: 161 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,184 @@
1-
"""Manages all agent communication and chat strategies for the SQL agents."""
1+
"""Optimized CommsManager with parallel processing and performance improvements."""
22

3-
from semantic_kernel.agents import AgentGroupChat # pylint: disable=E0611
3+
import asyncio
4+
import logging
5+
import re
6+
from typing import AsyncIterable, ClassVar, List
7+
from concurrent.futures import ThreadPoolExecutor
8+
9+
from semantic_kernel.agents import AgentGroupChat
410
from semantic_kernel.agents.strategies import (
511
SequentialSelectionStrategy,
612
TerminationStrategy,
713
)
14+
from semantic_kernel.contents import ChatMessageContent
15+
from semantic_kernel.exceptions import AgentInvokeException
816

917
from sql_agents.agents.migrator.response import MigratorResponse
1018
from sql_agents.helpers.models import AgentType
1119

1220

1321
class CommsManager:
14-
"""Manages all agent communication and selection strategies for the SQL agents."""
15-
16-
group_chat: AgentGroupChat = None
22+
"""Optimized CommsManager with parallel processing and performance improvements."""
1723

18-
class SelectionStrategy(SequentialSelectionStrategy):
19-
"""A strategy for determining which agent should take the next turn in the chat."""
20-
21-
# Select the next agent that should take the next turn in the chat
22-
async def select_agent(self, agents, history):
23-
"""Check which agent should take the next turn in the chat."""
24-
match history[-1].name:
25-
case AgentType.MIGRATOR.value:
26-
# The Migrator should go first
27-
agent_name = AgentType.PICKER.value
28-
return next(
29-
(agent for agent in agents if agent.name == agent_name), None
30-
)
31-
# The Incident Manager should go after the User or the Devops Assistant
32-
case AgentType.PICKER.value:
33-
agent_name = AgentType.SYNTAX_CHECKER.value
34-
return next(
35-
(agent for agent in agents if agent.name == agent_name), None
36-
)
37-
case AgentType.SYNTAX_CHECKER.value:
38-
agent_name = AgentType.FIXER.value
39-
return next(
40-
(agent for agent in agents if agent.name == agent_name),
41-
None,
42-
)
43-
case AgentType.FIXER.value:
44-
# The Fixer should always go after the Syntax Checker
45-
agent_name = AgentType.SYNTAX_CHECKER.value
46-
return next(
47-
(agent for agent in agents if agent.name == agent_name), None
48-
)
49-
case "candidate":
50-
# The candidate message is created in the orchestration loop to pass the
51-
# candidate and source sql queries to the Semantic Verifier
52-
# It is created when the Syntax Checker returns an empty list of errors
53-
agent_name = AgentType.SEMANTIC_VERIFIER.value
54-
return next(
55-
(agent for agent in agents if agent.name == agent_name),
56-
None,
57-
)
58-
case _:
59-
# Start run with this one - no history
60-
return next(
61-
(
62-
agent
63-
for agent in agents
64-
if agent.name == AgentType.MIGRATOR.value
65-
),
66-
None,
67-
)
24+
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
25+
_EXTRACT_WAIT_TIME = r"in (\d+) seconds"
6826

69-
# class for termination strategy
70-
class ApprovalTerminationStrategy(TerminationStrategy):
71-
"""
72-
A strategy for determining when an agent should terminate.
73-
This, combined with the maximum_iterations setting on the group chat, determines
74-
when the agents are finished processing a file when there are no errors.
75-
"""
27+
def __init__(
28+
self,
29+
agent_dict: dict[AgentType, object],
30+
exception_types: tuple = (Exception,),
31+
max_retries: int = 3, # reduc from 10
32+
initial_delay: float = 0.5, # reduced from 1.0
33+
backoff_factor: float = 1.5, # reduced from 2.0
34+
simple_truncation: int = 50, # more aggr truncation
35+
batch_size: int = 10, # process in batches
36+
max_workers: int = 4, # parallel processing
37+
):
38+
self.max_retries = max_retries
39+
self.initial_delay = initial_delay
40+
self.backoff_factor = backoff_factor
41+
self.exception_types = exception_types
42+
self.simple_truncation = simple_truncation
43+
self.batch_size = batch_size
44+
self.max_workers = max_workers
45+
46+
# Thread pool for parallel processing
47+
self.executor = ThreadPoolExecutor(max_workers=max_workers)
7648

77-
async def should_agent_terminate(self, agent, history):
78-
"""Check if the agent should terminate."""
79-
# May need to convert to models to get usable content using history[-1].name
80-
terminate: bool = False
81-
lower_case_hist: str = history[-1].content.lower()
82-
match history[-1].name:
83-
case AgentType.MIGRATOR.value:
84-
response = MigratorResponse.model_validate_json(
85-
lower_case_hist or ""
86-
)
87-
if (
88-
response.input_error is not None
89-
or response.rai_error is not None
90-
):
91-
terminate = True
92-
case AgentType.SEMANTIC_VERIFIER.value:
93-
# Always terminate after the Semantic Verifier runs
94-
terminate = True
95-
case _:
96-
# If the agent is not the Migrator or Semantic Verifier, don't terminate
97-
# Note that the Syntax Checker and Fixer loop are only terminated by correct SQL
98-
# or by iterations exceeding the max_iterations setting
99-
pass
100-
101-
return terminate
102-
103-
def __init__(self, agent_dict):
104-
"""Initialize the CommsManager and agent_chat with the given agents."""
10549
self.group_chat = AgentGroupChat(
10650
agents=agent_dict.values(),
107-
termination_strategy=self.ApprovalTerminationStrategy(
51+
termination_strategy=self.OptimizedTerminationStrategy(
10852
agents=[
10953
agent_dict[AgentType.MIGRATOR],
11054
agent_dict[AgentType.SEMANTIC_VERIFIER],
11155
],
112-
maximum_iterations=10,
56+
maximum_iterations=5, # Reduced from 10
11357
automatic_reset=True,
11458
),
115-
selection_strategy=self.SelectionStrategy(agents=agent_dict.values()),
59+
selection_strategy=self.ParallelSelectionStrategy(
60+
agents=agent_dict.values(),
61+
max_workers=max_workers
62+
),
11663
)
64+
65+
async def async_invoke_batch(self, inputs: List[str]) -> AsyncIterable[ChatMessageContent]:
66+
"""Process multiple inputs in parallel batches."""
67+
# Process inputs in batches
68+
for i in range(0, len(inputs), self.batch_size):
69+
batch = inputs[i:i + self.batch_size]
70+
71+
# Process batch in parallel
72+
tasks = [self._process_single_input(input_item) for input_item in batch]
73+
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
74+
75+
for result in batch_results:
76+
if isinstance(result, Exception):
77+
self.logger.error(f"Batch processing error: {result}")
78+
continue
79+
80+
async for item in result:
81+
yield item
82+
83+
async def _process_single_input(self, input_item: str) -> AsyncIterable[ChatMessageContent]:
84+
"""Process a single input with optimized retry logic."""
85+
attempt = 0
86+
current_delay = self.initial_delay
87+
88+
while attempt < self.max_retries:
89+
try:
90+
# Aggressive history truncation
91+
if len(self.group_chat.history) > self.simple_truncation:
92+
# Keep only the most recent messages
93+
self.group_chat.history = self.group_chat.history[-self.simple_truncation:]
94+
95+
# Add input to chat
96+
self.group_chat.add_chat_message(ChatMessageContent(
97+
role="user",
98+
content=input_item
99+
))
100+
101+
async_iter = self.group_chat.invoke()
102+
async for item in async_iter:
103+
yield item
104+
break
105+
106+
except AgentInvokeException as aie:
107+
attempt += 1
108+
if attempt >= self.max_retries:
109+
self.logger.error(
110+
"Input processing failed after %d attempts: %s",
111+
self.max_retries, str(aie)
112+
)
113+
# Don't raise, continue with next input
114+
break
115+
116+
# Faster retry with shorter delays
117+
match = re.search(self._EXTRACT_WAIT_TIME, str(aie))
118+
if match:
119+
current_delay = min(int(match.group(1)), 5) # Cap at 5 seconds
120+
else:
121+
current_delay = min(current_delay * self.backoff_factor, 10) # Cap at 10 seconds
122+
123+
self.logger.warning(
124+
"Attempt %d/%d failed. Retrying in %.2f seconds...",
125+
attempt, self.max_retries, current_delay
126+
)
127+
await asyncio.sleep(current_delay)
128+
129+
class ParallelSelectionStrategy(SequentialSelectionStrategy):
130+
"""Optimized selection strategy with parallel processing capabilities."""
131+
132+
def __init__(self, agents, max_workers: int = 4):
133+
super().__init__(agents)
134+
self.max_workers = max_workers
135+
136+
async def select_agent(self, agents, history):
137+
"""Select agent with optimized logic and parallel processing hints."""
138+
if not history:
139+
return next((agent for agent in agents if agent.name == AgentType.MIGRATOR.value), None)
140+
141+
last_agent = history[-1].name
142+
143+
# Optimized selection logic with fewer transitions
144+
agent_transitions = {
145+
AgentType.MIGRATOR.value: AgentType.PICKER.value,
146+
AgentType.PICKER.value: AgentType.SYNTAX_CHECKER.value,
147+
AgentType.SYNTAX_CHECKER.value: AgentType.FIXER.value,
148+
AgentType.FIXER.value: AgentType.SEMANTIC_VERIFIER.value, # Skip syntax check
149+
"candidate": AgentType.SEMANTIC_VERIFIER.value,
150+
}
151+
152+
next_agent_name = agent_transitions.get(last_agent, AgentType.MIGRATOR.value)
153+
return next((agent for agent in agents if agent.name == next_agent_name), None)
154+
155+
class OptimizedTerminationStrategy(TerminationStrategy):
156+
"""Optimized termination strategy with faster decision making."""
157+
158+
async def should_agent_terminate(self, agent, history):
159+
"""Determine termination with optimized checks."""
160+
if not history:
161+
return False
162+
163+
last_message = history[-1]
164+
lower_case_content = last_message.content.lower()
165+
166+
# Fast termination checks
167+
if last_message.name == AgentType.SEMANTIC_VERIFIER.value:
168+
return True
169+
170+
if last_message.name == AgentType.MIGRATOR.value:
171+
try:
172+
# Faster JSON parsing with error handling
173+
response = MigratorResponse.model_validate_json(lower_case_content or "{}")
174+
return bool(response.input_error or response.rai_error)
175+
except Exception:
176+
# If parsing fails, assume no termination needed
177+
return False
178+
179+
return False
180+
181+
def cleanup(self):
182+
"""Clean up resources."""
183+
if hasattr(self, 'executor'):
184+
self.executor.shutdown(wait=False)

src/backend/sql_agents/process_batch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
It is the main entry point for the SQL migration process.
55
"""
66

7-
import asyncio
87
import logging
98

109
from api.status_updates import send_status_update
@@ -132,7 +131,7 @@ async def process_batch_async(
132131
else:
133132
await batch_service.update_file_counts(file["file_id"])
134133
# TEMPORARY: awaiting bug fix for rate limits
135-
await asyncio.sleep(5)
134+
#await asyncio.sleep(5)
136135
except UnicodeDecodeError as ucde:
137136
logger.error("Error decoding file: %s", file)
138137
logger.error("Error decoding file. %s", ucde)

0 commit comments

Comments
 (0)