@@ -33,6 +33,7 @@ def __init__(self, args, agent_config=None):
3333 self .agent_config = agent_config
3434 self .topic_start_train = None
3535 self .topic_stop_train = None
36+ self .topic_complete_job = None
3637 self .topic_report_status = None
3738 self .topic_ota_msg = None
3839 self .topic_response_device_info = None
@@ -44,7 +45,7 @@ def __init__(self, args, agent_config=None):
4445 self .run_as_cloud_server = False
4546 self .run_as_edge_server_and_agent = False
4647 self .run_as_cloud_server_and_agent = False
47- self .enable_simulation_cloud_agent = True
48+ self .enable_simulation_cloud_agent = False
4849 self .use_local_process_as_cloud_server = False
4950 self .ota_upgrade = FedMLOtaUpgrade (edge_id = args .edge_id )
5051 self .running_request_json = dict ()
@@ -61,6 +62,9 @@ def generate_topics(self):
6162 # The topi for stopping training
6263 self .topic_stop_train = "mlops/flserver_agent_" + str (self .edge_id ) + "/stop_train"
6364
65+ # The topic for completing job
66+ self .topic_complete_job = GeneralConstants .get_topic_complete_job (self .edge_id )
67+
6468 # The topic for reporting current device status.
6569 self .topic_report_status = "mlops/report_device_status"
6670
@@ -89,6 +93,7 @@ def generate_topics(self):
8993 self .subscribed_topics .clear ()
9094 self .add_subscribe_topic (self .topic_start_train )
9195 self .add_subscribe_topic (self .topic_stop_train )
96+ self .add_subscribe_topic (self .topic_complete_job )
9297 self .add_subscribe_topic (self .topic_report_status )
9398 self .add_subscribe_topic (self .topic_ota_msg )
9499 self .add_subscribe_topic (self .topic_response_device_info )
@@ -103,6 +108,7 @@ def add_protocol_handler(self):
103108 # Add the message listeners for all topics
104109 self .add_message_listener (self .topic_start_train , self .callback_start_train )
105110 self .add_message_listener (self .topic_stop_train , self .callback_stop_train )
111+ self .add_message_listener (self .topic_complete_job , self .callback_complete_job )
106112 self .add_message_listener (self .topic_ota_msg , FedMLBaseMasterProtocolManager .callback_server_ota_msg )
107113 self .add_message_listener (self .topic_report_status , self .callback_report_current_status )
108114 self .add_message_listener (self .topic_response_device_info , self .callback_response_device_info )
@@ -140,12 +146,6 @@ def callback_start_train(self, topic=None, payload=None):
140146 except Exception :
141147 pass
142148
143- # Parse the message when running in the cloud server mode.
144- if self .run_as_cloud_server :
145- message_bytes = payload .encode ("ascii" )
146- base64_bytes = base64 .b64decode (message_bytes )
147- payload = base64_bytes .decode ("ascii" )
148-
149149 # Parse the parameters
150150 # [NOTES] Example Request JSON:
151151 # https://fedml-inc.larksuite.com/wiki/ScnIwUif9iupbjkYS0LuBrd6sod#WjbEdhYrvogmlGxKTOGu98C6sSb
@@ -264,6 +264,9 @@ def callback_stop_train(self, topic, payload, use_payload=None):
264264 run_id = request_json .get ("runId" , None )
265265 run_id = request_json .get ("id" , None ) if run_id is None else run_id
266266 run_id_str = str (run_id )
267+ server_id = request_json .get ("serverId" , None )
268+ if server_id is None :
269+ server_id = request_json .get ("server_id" , None )
267270
268271 # Broadcast the job status to all edges
269272 self .rebuild_status_center (self .get_status_queue ())
@@ -274,7 +277,24 @@ def callback_stop_train(self, topic, payload, use_payload=None):
274277 self .running_request_json .pop (run_id_str )
275278
276279 # Stop the job runner
277- self ._get_job_runner_manager ().stop_job_runner (run_id )
280+ self ._get_job_runner_manager ().stop_job_runner (
281+ run_id , args = self .args , server_id = server_id , request_json = request_json ,
282+ run_as_cloud_agent = self .run_as_cloud_agent )
283+
284+ def callback_complete_job (self , topic , payload ):
285+ # Parse the parameters.
286+ request_json = json .loads (payload )
287+ run_id = request_json .get ("runId" , None )
288+ run_id = request_json .get ("id" , None ) if run_id is None else run_id
289+ run_id_str = str (run_id )
290+ server_id = request_json .get ("serverId" , None )
291+ if server_id is None :
292+ server_id = request_json .get ("server_id" , None )
293+
294+ self ._process_job_complete_status (run_id , server_id , request_json )
295+
296+ def _process_job_complete_status (self , run_id , server_id , complete_payload ):
297+ pass
278298
279299 def callback_run_logs (self , topic , payload ):
280300 run_id = str (topic ).split ('/' )[- 1 ]
@@ -498,6 +518,11 @@ def send_training_stop_request_to_specific_edge(self, edge_id, payload):
498518 logging .info ("stop_train: send topic " + topic_stop_train )
499519 self .message_center .send_message (topic_stop_train , payload )
500520
521+ def send_training_stop_request_to_cloud_server (self , edge_id , payload ):
522+ topic_stop_train = "mlops/flserver_agent_" + str (edge_id ) + "/stop_train"
523+ logging .info ("stop_train: send topic " + topic_stop_train )
524+ self .message_center .send_message (topic_stop_train , payload )
525+
501526 def send_status_check_msg (self , run_id , edge_id , server_id , context = None ):
502527 topic_status_check = f"server/client/request_device_info/{ edge_id } "
503528 payload = {"server_id" : server_id , "run_id" : run_id }
0 commit comments