1- """Manages all agent communication and chat strategies for the SQL agents ."""
1+ """Optimized CommsManager with parallel processing and performance improvements ."""
22
3- from semantic_kernel .agents import AgentGroupChat # pylint: disable=E0611
3+ import asyncio
4+ import logging
5+ import re
6+ from typing import AsyncIterable , ClassVar , List
7+ from concurrent .futures import ThreadPoolExecutor
8+
9+ from semantic_kernel .agents import AgentGroupChat
410from semantic_kernel .agents .strategies import (
511 SequentialSelectionStrategy ,
612 TerminationStrategy ,
713)
14+ from semantic_kernel .contents import ChatMessageContent
15+ from semantic_kernel .exceptions import AgentInvokeException
816
917from sql_agents .agents .migrator .response import MigratorResponse
1018from sql_agents .helpers .models import AgentType
1119
1220
class CommsManager:
    """Coordinate the SQL agents' group chat: turn order, termination and retries.

    The manager owns one ``AgentGroupChat`` built from the supplied agents and
    offers helpers that feed inputs to it in batches with bounded retries.
    """

    logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
    # Extracts the wait hint from exception text of the form "... in <N> seconds".
    _EXTRACT_WAIT_TIME = r"in (\d+) seconds"

    def __init__(
        self,
        agent_dict: dict[AgentType, object],
        exception_types: tuple = (Exception,),
        max_retries: int = 3,
        initial_delay: float = 0.5,
        backoff_factor: float = 1.5,
        simple_truncation: int = 50,
        batch_size: int = 10,
        max_workers: int = 4,
    ):
        """Store the tuning knobs and build the group chat.

        Args:
            agent_dict: Agents keyed by ``AgentType``; must contain at least the
                ``MIGRATOR`` and ``SEMANTIC_VERIFIER`` entries.
            exception_types: Stored on the instance; not consulted by any method
                of this class.
            max_retries: Maximum number of invocation attempts per input.
            initial_delay: Seed (seconds) for the retry back-off.
            backoff_factor: Multiplier applied to the delay after each failure.
            simple_truncation: Maximum number of history messages kept before a
                new input is added.
            batch_size: Number of inputs gathered together by
                ``async_invoke_batch``.
            max_workers: Size of the thread pool; also forwarded to the
                selection strategy.

        Raises:
            KeyError: If ``agent_dict`` lacks the ``MIGRATOR`` or
                ``SEMANTIC_VERIFIER`` entry.
        """
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.backoff_factor = backoff_factor
        self.exception_types = exception_types
        self.simple_truncation = simple_truncation
        self.batch_size = batch_size
        self.max_workers = max_workers

        # NOTE(review): no method of this class submits work to this pool; it is
        # only shut down by cleanup(). Kept in case external callers use it.
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

        # Materialise once so both consumers receive the same concrete list
        # instead of a live dict view.
        agents = list(agent_dict.values())

        self.group_chat = AgentGroupChat(
            agents=agents,
            termination_strategy=self.OptimizedTerminationStrategy(
                agents=[
                    agent_dict[AgentType.MIGRATOR],
                    agent_dict[AgentType.SEMANTIC_VERIFIER],
                ],
                maximum_iterations=5,
                automatic_reset=True,
            ),
            selection_strategy=self.ParallelSelectionStrategy(
                agents=agents,
                max_workers=max_workers,
            ),
        )

    async def async_invoke_batch(self, inputs: List[str]) -> AsyncIterable[ChatMessageContent]:
        """Run ``inputs`` through the group chat in batches and yield the messages.

        Inputs are taken ``batch_size`` at a time. Each batch is awaited as a
        whole, then its messages are yielded in input order. An input whose
        processing raised is logged and skipped; the remaining inputs are still
        reported.

        Args:
            inputs: Raw user inputs to submit to the chat.

        Yields:
            The chat messages produced for each successfully processed input.
        """
        # A zero step makes range() raise and a negative one skips everything;
        # fall back to one input per batch in both cases.
        step = max(1, self.batch_size)

        for start in range(0, len(inputs), step):
            batch = inputs[start:start + step]

            # gather() only accepts awaitables, while _process_single_input
            # returns an async generator, so each one is drained in a coroutine.
            # NOTE(review): every task shares self.group_chat and its history;
            # confirm the chat tolerates overlapping invocations.
            batch_results = await asyncio.gather(
                *(self._collect_messages(input_item) for input_item in batch),
                return_exceptions=True,
            )

            for result in batch_results:
                # BaseException rather than Exception: a cancelled task is
                # reported as CancelledError, which is not an Exception.
                if isinstance(result, BaseException):
                    self.logger.error("Batch processing error: %s", result)
                    continue

                for item in result:
                    yield item

    async def _collect_messages(self, input_item: str) -> List[ChatMessageContent]:
        """Drain ``_process_single_input`` for one input into a list."""
        return [item async for item in self._process_single_input(input_item)]

    async def _process_single_input(self, input_item: str) -> AsyncIterable[ChatMessageContent]:
        """Submit one input to the group chat, retrying failed invocations.

        The history is trimmed and the input added once; only the chat
        invocation is retried, so a retry does not duplicate the user message.
        Messages already yielded before a failure are not withdrawn, so a retry
        can emit further messages for the same input. When all attempts fail the
        error is logged and the generator ends without raising.

        Args:
            input_item: Text submitted to the chat as a ``user`` message.

        Yields:
            Messages produced by the group chat invocation.
        """
        if self.max_retries <= 0:
            # No attempts allowed: leave the chat untouched.
            return

        # Keep only the most recent messages before adding the new input.
        # NOTE(review): this assumes group_chat.history supports len(), slicing
        # and reassignment - confirm against the Semantic Kernel history type.
        if len(self.group_chat.history) > self.simple_truncation:
            self.group_chat.history = self.group_chat.history[-self.simple_truncation:]

        self.group_chat.add_chat_message(ChatMessageContent(
            role="user",
            content=input_item,
        ))

        current_delay = self.initial_delay

        for attempt in range(1, self.max_retries + 1):
            try:
                async for item in self.group_chat.invoke():
                    yield item
                return

            except AgentInvokeException as aie:
                if attempt >= self.max_retries:
                    self.logger.error(
                        "Input processing failed after %d attempts: %s",
                        self.max_retries, str(aie)
                    )
                    # Deliberately not re-raised so the caller can move on to
                    # the next input.
                    return

                # Prefer the wait time quoted in the exception text (capped at
                # 5 s); otherwise grow the previous delay (capped at 10 s).
                match = re.search(self._EXTRACT_WAIT_TIME, str(aie))
                if match:
                    current_delay = min(int(match.group(1)), 5)
                else:
                    current_delay = min(current_delay * self.backoff_factor, 10)

                self.logger.warning(
                    "Attempt %d/%d failed. Retrying in %.2f seconds...",
                    attempt, self.max_retries, current_delay
                )
                await asyncio.sleep(current_delay)

    class ParallelSelectionStrategy(SequentialSelectionStrategy):
        """Pick the next agent from a fixed table keyed by the last speaker."""

        # Declared as a field rather than assigned in __init__.
        # NOTE(review): Semantic Kernel strategy classes are pydantic models,
        # which take keyword arguments only and reject assignment to undeclared
        # attributes - confirm against the installed version.
        max_workers: int = 4

        def __init__(self, agents=None, max_workers: int = 4, **kwargs):
            """Forward everything to the base strategy by keyword.

            Args:
                agents: Agents taking part in the chat, passed through to the
                    base class.
                max_workers: Stored on the strategy; not read by
                    ``select_agent``.
                **kwargs: Any further base-class options.
            """
            super().__init__(agents=agents, max_workers=max_workers, **kwargs)

        async def select_agent(self, agents, history):
            """Return the agent that should speak next, or ``None`` if it is absent.

            With an empty history, or a last speaker that is not in the table,
            the Migrator is selected.
            """
            if not history:
                return next((agent for agent in agents if agent.name == AgentType.MIGRATOR.value), None)

            last_agent = history[-1].name

            # NOTE(review): the Fixer hands over directly to the Semantic
            # Verifier with no further Syntax Checker pass - confirm that
            # unchecked fixes are meant to reach the verifier.
            agent_transitions = {
                AgentType.MIGRATOR.value: AgentType.PICKER.value,
                AgentType.PICKER.value: AgentType.SYNTAX_CHECKER.value,
                AgentType.SYNTAX_CHECKER.value: AgentType.FIXER.value,
                AgentType.FIXER.value: AgentType.SEMANTIC_VERIFIER.value,
                "candidate": AgentType.SEMANTIC_VERIFIER.value,
            }

            next_agent_name = agent_transitions.get(last_agent, AgentType.MIGRATOR.value)
            return next((agent for agent in agents if agent.name == next_agent_name), None)

    class OptimizedTerminationStrategy(TerminationStrategy):
        """End the chat after the Semantic Verifier or on a Migrator error."""

        async def should_agent_terminate(self, agent, history):
            """Decide whether the chat should stop after the latest message.

            Returns:
                ``True`` when the last message is from the Semantic Verifier, or
                is a Migrator response carrying a truthy ``input_error`` or
                ``rai_error``; ``False`` otherwise, including when the Migrator
                payload cannot be parsed.
            """
            if not history:
                return False

            last_message = history[-1]

            if last_message.name == AgentType.SEMANTIC_VERIFIER.value:
                return True

            if last_message.name != AgentType.MIGRATOR.value:
                return False

            # Lower-case only the payload that is actually parsed, and tolerate
            # a message without content instead of failing on None.lower().
            lower_case_content = (last_message.content or "").lower()

            try:
                response = MigratorResponse.model_validate_json(lower_case_content or "{}")
                return bool(response.input_error or response.rai_error)
            except Exception:
                # Best effort: an unreadable Migrator response must not stop the
                # chat, but the failure is recorded rather than hidden.
                logging.getLogger(__name__).warning(
                    "Could not parse Migrator response; continuing without termination.",
                    exc_info=True,
                )
                return False

    def cleanup(self):
        """Shut down the thread pool without waiting for pending work.

        Does nothing when the instance has no ``executor`` attribute (for
        example when ``__init__`` failed before creating it).
        """
        executor = getattr(self, "executor", None)
        if executor is not None:
            executor.shutdown(wait=False)
0 commit comments