Skip to content

Commit 2ad110b

Browse files
authored
Merge pull request #2130 from FedML-AI/alexleung/dev_branch_latest
Alexleung/dev branch latest
2 parents 84d6156 + 1162f6c commit 2ad110b

2 files changed

Lines changed: 21 additions & 19 deletions

File tree

python/fedml/computing/scheduler/master/base_master_protocol_manager.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -264,23 +264,26 @@ def callback_stop_train(self, topic, payload, use_payload=None):
264264
run_id = request_json.get("runId", None)
265265
run_id = request_json.get("id", None) if run_id is None else run_id
266266
run_id_str = str(run_id)
267+
edge_ids = request_json.get("edgeids", None)
267268
server_id = request_json.get("serverId", None)
268269
if server_id is None:
269270
server_id = request_json.get("server_id", None)
270-
edge_ids = request_json.get("edgeids", None)
271-
272-
# Stop the job runner
273-
self._get_job_runner_manager().stop_job_runner(
274-
run_id, args=self.args, server_id=server_id, request_json=request_json,
275-
run_as_cloud_agent=self.run_as_cloud_agent)
271+
server_agent_id = server_id
276272

277273
# Cleanup the cached object
278274
if self.running_request_json.get(run_id_str, None) is not None:
279275
self.running_request_json.pop(run_id_str)
280276

277+
# If it is the cloud agent, then forward the stopping request to the corresponding cloud server.
278+
if self.run_as_cloud_agent:
279+
server_agent_id = self.edge_id
280+
topic_stop_train_to_cloud_server = f"mlops/flserver_agent_{server_id}/stop_train"
281+
self.message_center.send_message(topic_stop_train_to_cloud_server, payload)
282+
return
283+
281284
# Reset all edge status and server status
282285
for iter_edge_id in edge_ids:
283-
self.generate_status_report(run_id, iter_edge_id, server_agent_id=server_id).\
286+
self.generate_status_report(run_id, iter_edge_id, server_agent_id=server_agent_id).\
284287
report_client_id_status(iter_edge_id, GeneralConstants.MSG_MLOPS_SERVER_STATUS_KILLED,
285288
run_id=run_id, server_id=server_id)
286289

python/fedml/computing/scheduler/scheduler_core/status_manager_protocols.py

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def __init__(self, run_id=None, edge_id=None, server_id=None,
2323
self.edge_id = edge_id
2424
self.server_id = server_id
2525
self.edge_id_list = edge_id_list
26-
self.client_agent_active_list = dict()
26+
self.edge_status_dict = None
2727
self.running_scheduler_contract = running_scheduler_contract if running_scheduler_contract is not None else dict()
2828
self.message_reporter = MLOpsMetrics()
2929
self.message_reporter.set_messenger(message_center)
@@ -163,6 +163,8 @@ def status_center_process_master_status(self, topic, payload):
163163
status = request_json["status"]
164164
edge_id = request_json["edge_id"]
165165
server_id = request_json.get("server_id", None)
166+
if server_id is None or str(server_id) == "0":
167+
server_id = self.server_id
166168
run_id_str = str(run_id)
167169

168170
# Process the job status
@@ -185,8 +187,7 @@ def process_job_status_consensus(self, run_id, master_id, status):
185187
status = self.get_entire_job_status()
186188

187189
# Set the device status based on the job status
188-
edge_id_status_dict = self.client_agent_active_list.get(f"{run_id}", {})
189-
for edge_id_item, edge_status_item in edge_id_status_dict.items():
190+
for edge_id_item, edge_status_item in self.edge_status_dict.items():
190191
if edge_id_item == "server":
191192
continue
192193

@@ -233,31 +234,29 @@ def status_center_process_slave_status(self, topic, payload):
233234
init_edge_id_list = payload_json.get("init_all_edge_id_list", None)
234235
init_server_id = payload_json.get("init_server_id", None)
235236

236-
active_item_dict = self.client_agent_active_list.get(f"{run_id}", None)
237-
if active_item_dict is None:
238-
self.client_agent_active_list[f"{run_id}"] = dict()
237+
if self.edge_status_dict is None:
238+
self.edge_status_dict = dict()
239239

240240
if init_edge_id_list is not None:
241-
self.client_agent_active_list[f"{run_id}"][f"server"] = init_server_id
241+
self.edge_status_dict[f"server"] = init_server_id
242242
for edge_id_item in init_edge_id_list:
243-
self.client_agent_active_list[f"{run_id}"][f"{edge_id_item}"] = \
243+
self.edge_status_dict[f"{edge_id_item}"] = \
244244
ClientConstants.MSG_MLOPS_CLIENT_STATUS_IDLE
245245

246246
if run_id is not None and edge_id is not None:
247-
self.client_agent_active_list[f"{run_id}"][f"{edge_id}"] = status
247+
self.edge_status_dict[f"{edge_id}"] = status
248248

249249
self.process_device_status(run_id, edge_id, status)
250250

251251
def process_device_status(self, run_id, edge_id, status):
252252
number_of_failed_edges = 0
253253
number_of_finished_edges = 0
254254
number_of_killed_edges = 0
255-
edge_id_status_dict = self.client_agent_active_list.get(f"{run_id}", {})
256-
server_id = edge_id_status_dict.get("server", 0)
255+
server_id = self.edge_status_dict.get("server", 0)
257256
enable_fault_tolerance, fault_tolerance_rate = self.parse_fault_tolerance_params(run_id)
258257
running_edges_list = list()
259258
edge_nums = 0
260-
for edge_id_item, status_item in edge_id_status_dict.items():
259+
for edge_id_item, status_item in self.edge_status_dict.items():
261260
if edge_id_item == "server":
262261
continue
263262

0 commit comments

Comments (0)