
Commit 8e03183

Improving pending requests counter robustness.
Parent: 1cc1552

2 files changed: 88 additions & 78 deletions

python/fedml/computing/scheduler/model_scheduler/device_model_cache.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -985,5 +985,9 @@ def update_pending_requests_counter(self, increase=False, decrease=False) -> int
         if increase:
             self.redis_connection.incr(self.FEDML_PENDING_REQUESTS_COUNTER)
         if decrease:
-            self.redis_connection.decr(self.FEDML_PENDING_REQUESTS_COUNTER)
+            # Making sure the counter never becomes negative!
+            if self.get_pending_requests_counter() < 0:
+                self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
+            else:
+                self.redis_connection.decr(self.FEDML_PENDING_REQUESTS_COUNTER)
         return self.get_pending_requests_counter()
```
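Note that the guard above is a read (`get_pending_requests_counter`) followed by a separate write, so two concurrent decrements can still interleave between the check and the `DECR`. For reference, here is a minimal sketch of an atomic floor-at-zero decrement using a server-side Lua script with redis-py; the key name, client setup, and helper function are illustrative assumptions, not part of this commit:

```python
import redis

# Hypothetical key name for illustration; the real class uses
# self.FEDML_PENDING_REQUESTS_COUNTER.
PENDING_KEY = "FEDML_PENDING_REQUESTS_COUNTER"

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

# DECR the counter and clamp it at zero in a single server-side step,
# so concurrent decrements cannot interleave between the read and the write.
DECR_FLOOR_ZERO = """
local v = redis.call('DECR', KEYS[1])
if v < 0 then
    redis.call('SET', KEYS[1], 0)
    return 0
end
return v
"""
decr_floor_zero = r.register_script(DECR_FLOOR_ZERO)

def decrease_pending_requests() -> int:
    # Returns the counter value after the clamped decrement.
    return int(decr_floor_zero(keys=[PENDING_KEY]))
```

Because the script executes atomically on the Redis server, the counter can never be observed below zero even with concurrent callers.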

python/fedml/computing/scheduler/model_scheduler/device_model_inference.py

Lines changed: 83 additions & 77 deletions
```diff
@@ -173,92 +173,98 @@ async def _predict(
         header=None
 ) -> Union[MutableMapping[str, Any], Response, StreamingResponse]:
 
+    # Always increase the pending requests counter on a new incoming request.
     FEDML_MODEL_CACHE.update_pending_requests_counter(increase=True)
     inference_response = {}
 
-    in_end_point_id = end_point_id
-    in_end_point_name = input_json.get("end_point_name", None)
-    in_model_name = input_json.get("model_name", None)
-    in_model_version = input_json.get("model_version", None)
-    in_end_point_token = input_json.get("token", None)
-    in_return_type = "default"
-    if header is not None:
-        in_return_type = header.get("Accept", "default")
-
-    if in_model_version is None:
-        in_model_version = "*" # * | latest | specific version
-
-    start_time = time.time_ns()
-
-    # Allow missing end_point_name and model_name in the input parameters.
-    if in_model_name is None or in_end_point_name is None:
-        ret_endpoint_name, ret_model_name = retrieve_info_by_endpoint_id(in_end_point_id, in_end_point_name)
-        if in_model_name is None:
-            in_model_name = ret_model_name
-        if in_end_point_name is None:
-            in_end_point_name = ret_endpoint_name
-
-    # Authenticate request token
-    inference_response = {}
-    if auth_request_token(in_end_point_id, in_end_point_name, in_model_name, in_end_point_token):
-        # Check the endpoint is activated
-        if not is_endpoint_activated(in_end_point_id):
-            inference_response = {"error": True, "message": "endpoint is not activated."}
+    try:
+        in_end_point_id = end_point_id
+        in_end_point_name = input_json.get("end_point_name", None)
+        in_model_name = input_json.get("model_name", None)
+        in_model_version = input_json.get("model_version", None)
+        in_end_point_token = input_json.get("token", None)
+        in_return_type = "default"
+        if header is not None:
+            in_return_type = header.get("Accept", "default")
+
+        if in_model_version is None:
+            in_model_version = "*" # * | latest | specific version
+
+        start_time = time.time_ns()
+
+        # Allow missing end_point_name and model_name in the input parameters.
+        if in_model_name is None or in_end_point_name is None:
+            ret_endpoint_name, ret_model_name = retrieve_info_by_endpoint_id(in_end_point_id, in_end_point_name)
+            if in_model_name is None:
+                in_model_name = ret_model_name
+            if in_end_point_name is None:
+                in_end_point_name = ret_endpoint_name
+
+        # Authenticate request token
+        if auth_request_token(in_end_point_id, in_end_point_name, in_model_name, in_end_point_token):
+            # Check the endpoint is activated
+            if not is_endpoint_activated(in_end_point_id):
+                inference_response = {"error": True, "message": "endpoint is not activated."}
+                logging_inference_request(input_json, inference_response)
+                FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
+                return inference_response
+
+            # Found idle inference device
+            idle_device, end_point_id, model_id, model_name, model_version, inference_host, inference_output_url = \
+                found_idle_inference_device(in_end_point_id, in_end_point_name, in_model_name, in_model_version)
+            if idle_device is None or idle_device == "":
+                FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
+                return {"error": True, "error_code": status.HTTP_404_NOT_FOUND,
+                        "message": "can not found active inference worker for this endpoint."}
+
+            # Start timing for model metrics
+            model_metrics = FedMLModelMetrics(end_point_id, in_end_point_name,
+                                              model_id, in_model_name, model_version,
+                                              Settings.model_infer_url,
+                                              Settings.redis_addr,
+                                              Settings.redis_port,
+                                              Settings.redis_password,
+                                              version=Settings.version)
+            # Setting time to the time before authentication and idle device discovery.
+            model_metrics.set_start_time(start_time)
+
+            # Send inference request to idle device
+            logging.info("inference url {}.".format(inference_output_url))
+            if inference_output_url != "":
+                input_list = input_json.get("inputs", input_json)
+                stream_flag = input_json.get("stream", False)
+                input_list["stream"] = input_list.get("stream", stream_flag)
+                output_list = input_json.get("outputs", [])
+                inference_response = await send_inference_request(
+                    idle_device,
+                    end_point_id,
+                    inference_output_url,
+                    input_list,
+                    output_list,
+                    inference_type=in_return_type)
+
+                # Calculate model metrics
+                try:
+                    model_metrics.calc_metrics(end_point_id, in_end_point_name,
+                                               model_id, model_name, model_version,
+                                               inference_output_url, idle_device)
+                except Exception as e:
+                    logging.info("Calculate Inference Metrics Exception: {}".format(traceback.format_exc()))
+                    pass
+
             logging_inference_request(input_json, inference_response)
             FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
             return inference_response
-
-        # Found idle inference device
-        idle_device, end_point_id, model_id, model_name, model_version, inference_host, inference_output_url = \
-            found_idle_inference_device(in_end_point_id, in_end_point_name, in_model_name, in_model_version)
-        if idle_device is None or idle_device == "":
+        else:
+            inference_response = {"error": True, "message": "token is not valid."}
+            logging_inference_request(input_json, inference_response)
             FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
-            return {"error": True, "error_code": status.HTTP_404_NOT_FOUND,
-                    "message": "can not found active inference worker for this endpoint."}
-
-        # Start timing for model metrics
-        model_metrics = FedMLModelMetrics(end_point_id, in_end_point_name,
-                                          model_id, in_model_name, model_version,
-                                          Settings.model_infer_url,
-                                          Settings.redis_addr,
-                                          Settings.redis_port,
-                                          Settings.redis_password,
-                                          version=Settings.version)
-        # Setting time to the time before authentication and idle device discovery.
-        model_metrics.set_start_time(start_time)
-
-        # Send inference request to idle device
-        logging.info("inference url {}.".format(inference_output_url))
-        if inference_output_url != "":
-            input_list = input_json.get("inputs", input_json)
-            stream_flag = input_json.get("stream", False)
-            input_list["stream"] = input_list.get("stream", stream_flag)
-            output_list = input_json.get("outputs", [])
-            inference_response = await send_inference_request(
-                idle_device,
-                end_point_id,
-                inference_output_url,
-                input_list,
-                output_list,
-                inference_type=in_return_type)
+            return inference_response
 
-        # Calculate model metrics
-        try:
-            model_metrics.calc_metrics(end_point_id, in_end_point_name,
-                                       model_id, model_name, model_version,
-                                       inference_output_url, idle_device)
-        except Exception as e:
-            logging.info("Calculate Inference Metrics Exception: {}".format(traceback.format_exc()))
-            pass
-
-        logging_inference_request(input_json, inference_response)
-        FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
-        return inference_response
-    else:
-        inference_response = {"error": True, "message": "token is not valid."}
-        logging_inference_request(input_json, inference_response)
+    except Exception as e:
+        logging.error("Inference Exception: {}".format(traceback.format_exc()))
+        # Need to reduce the pending requests counter in whatever exception that may be raised.
         FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
-        return inference_response
 
 
 def retrieve_info_by_endpoint_id(end_point_id, in_end_point_name=None, in_model_name=None,
```
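The handler now decrements the counter on every successful return path and once more in the new `except` block, so the counter stays balanced no matter how `_predict` exits. One way to collapse those repeated call sites into a single place is a small context manager whose `finally` clause always runs. The sketch below is a hedged illustration of that pattern against the `FEDML_MODEL_CACHE` interface shown above; it is not what the commit implements:

```python
from contextlib import contextmanager

@contextmanager
def pending_request(cache):
    """Pair the increase/decrease calls so the counter stays balanced.

    `cache` stands in for FEDML_MODEL_CACHE; this helper is illustrative
    and is not part of the commit.
    """
    cache.update_pending_requests_counter(increase=True)
    try:
        yield
    finally:
        # Runs on normal returns and on raised exceptions alike, so every
        # increment is matched by exactly one decrement.
        cache.update_pending_requests_counter(decrease=True)
```

With this helper, the body of `_predict` could be wrapped in `with pending_request(FEDML_MODEL_CACHE): ...` and each early `return` could drop its explicit `decrease=True` call; a plain `with` block is valid inside an `async def` as long as the enter/exit calls themselves do not block.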
