@@ -173,92 +173,102 @@ async def _predict(
         header=None
 ) -> Union[MutableMapping[str, Any], Response, StreamingResponse]:
 
+    # Always increase the pending requests counter on a new incoming request.
     FEDML_MODEL_CACHE.update_pending_requests_counter(increase=True)
     inference_response = {}
 
-    in_end_point_id = end_point_id
-    in_end_point_name = input_json.get("end_point_name", None)
-    in_model_name = input_json.get("model_name", None)
-    in_model_version = input_json.get("model_version", None)
-    in_end_point_token = input_json.get("token", None)
-    in_return_type = "default"
-    if header is not None:
-        in_return_type = header.get("Accept", "default")
-
-    if in_model_version is None:
-        in_model_version = "*"  # * | latest | specific version
-
-    start_time = time.time_ns()
-
-    # Allow missing end_point_name and model_name in the input parameters.
-    if in_model_name is None or in_end_point_name is None:
-        ret_endpoint_name, ret_model_name = retrieve_info_by_endpoint_id(in_end_point_id, in_end_point_name)
-        if in_model_name is None:
-            in_model_name = ret_model_name
-        if in_end_point_name is None:
-            in_end_point_name = ret_endpoint_name
-
-    # Authenticate request token
-    inference_response = {}
-    if auth_request_token(in_end_point_id, in_end_point_name, in_model_name, in_end_point_token):
-        # Check that the endpoint is activated
-        if not is_endpoint_activated(in_end_point_id):
-            inference_response = {"error": True, "message": "endpoint is not activated."}
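+    # Handle the whole request inside try/except so the pending-requests counter stays balanced on every exit path.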
+    try:
+        in_end_point_id = end_point_id
+        in_end_point_name = input_json.get("end_point_name", None)
+        in_model_name = input_json.get("model_name", None)
+        in_model_version = input_json.get("model_version", None)
+        in_end_point_token = input_json.get("token", None)
+        in_return_type = "default"
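+        # The caller's Accept header, when present, selects the response format.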
+        if header is not None:
+            in_return_type = header.get("Accept", "default")
+
+        if in_model_version is None:
+            in_model_version = "*"  # * | latest | specific version
+
+        start_time = time.time_ns()
+
+        # Allow missing end_point_name and model_name in the input parameters.
+        if in_model_name is None or in_end_point_name is None:
+            ret_endpoint_name, ret_model_name = retrieve_info_by_endpoint_id(in_end_point_id, in_end_point_name)
+            if in_model_name is None:
+                in_model_name = ret_model_name
+            if in_end_point_name is None:
+                in_end_point_name = ret_endpoint_name
+
+        # Authenticate request token
+        if auth_request_token(in_end_point_id, in_end_point_name, in_model_name, in_end_point_token):
+            # Check that the endpoint is activated
+            if not is_endpoint_activated(in_end_point_id):
+                inference_response = {"error": True, "message": "endpoint is not activated."}
+                logging_inference_request(input_json, inference_response)
+                FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
+                return inference_response
+
+            # Find an idle inference device
+            idle_device, end_point_id, model_id, model_name, model_version, inference_host, inference_output_url = \
+                found_idle_inference_device(in_end_point_id, in_end_point_name, in_model_name, in_model_version)
+            if idle_device is None or idle_device == "":
+                FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
+                return {"error": True, "error_code": status.HTTP_404_NOT_FOUND,
+                        "message": "cannot find an active inference worker for this endpoint."}
+
+            # Start timing for model metrics
+            model_metrics = FedMLModelMetrics(end_point_id, in_end_point_name,
+                                              model_id, in_model_name, model_version,
+                                              Settings.model_infer_url,
+                                              Settings.redis_addr,
+                                              Settings.redis_port,
+                                              Settings.redis_password,
+                                              version=Settings.version)
+            # Set the start time to the time before authentication and idle-device discovery.
+            model_metrics.set_start_time(start_time)
+
+            # Send the inference request to the idle device
+            logging.info("inference url {}.".format(inference_output_url))
+            if inference_output_url != "":
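+                # Fall back to the whole request body when there is no explicit "inputs"
+                # field, and propagate the top-level "stream" flag into the payload.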
+                input_list = input_json.get("inputs", input_json)
+                stream_flag = input_json.get("stream", False)
+                input_list["stream"] = input_list.get("stream", stream_flag)
+                output_list = input_json.get("outputs", [])
+                inference_response = await send_inference_request(
+                    idle_device,
+                    end_point_id,
+                    inference_output_url,
+                    input_list,
+                    output_list,
+                    inference_type=in_return_type)
+
+                # Calculate model metrics
+                try:
+                    model_metrics.calc_metrics(end_point_id, in_end_point_name,
+                                               model_id, model_name, model_version,
+                                               inference_output_url, idle_device)
+                except Exception as e:
+                    logging.info("Calculate Inference Metrics Exception: {}".format(traceback.format_exc()))
+                    pass
+
             logging_inference_request(input_json, inference_response)
             FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
             return inference_response
-
-        # Find an idle inference device
-        idle_device, end_point_id, model_id, model_name, model_version, inference_host, inference_output_url = \
-            found_idle_inference_device(in_end_point_id, in_end_point_name, in_model_name, in_model_version)
-        if idle_device is None or idle_device == "":
+        else:
+            inference_response = {"error": True, "message": "token is not valid."}
+            logging_inference_request(input_json, inference_response)
             FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
-            return {"error": True, "error_code": status.HTTP_404_NOT_FOUND,
-                    "message": "cannot find an active inference worker for this endpoint."}
-
-        # Start timing for model metrics
-        model_metrics = FedMLModelMetrics(end_point_id, in_end_point_name,
-                                          model_id, in_model_name, model_version,
-                                          Settings.model_infer_url,
-                                          Settings.redis_addr,
-                                          Settings.redis_port,
-                                          Settings.redis_password,
-                                          version=Settings.version)
-        # Set the start time to the time before authentication and idle-device discovery.
-        model_metrics.set_start_time(start_time)
-
-        # Send the inference request to the idle device
-        logging.info("inference url {}.".format(inference_output_url))
-        if inference_output_url != "":
-            input_list = input_json.get("inputs", input_json)
-            stream_flag = input_json.get("stream", False)
-            input_list["stream"] = input_list.get("stream", stream_flag)
-            output_list = input_json.get("outputs", [])
-            inference_response = await send_inference_request(
-                idle_device,
-                end_point_id,
-                inference_output_url,
-                input_list,
-                output_list,
-                inference_type=in_return_type)
+            return inference_response
 
-            # Calculate model metrics
-            try:
-                model_metrics.calc_metrics(end_point_id, in_end_point_name,
-                                           model_id, model_name, model_version,
-                                           inference_output_url, idle_device)
-            except Exception as e:
-                logging.info("Calculate Inference Metrics Exception: {}".format(traceback.format_exc()))
-                pass
-
-            logging_inference_request(input_json, inference_response)
-            FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
-            return inference_response
-    else:
-        inference_response = {"error": True, "message": "token is not valid."}
-        logging_inference_request(input_json, inference_response)
+    except Exception as e:
+        logging.error("Inference Exception: {}".format(traceback.format_exc()))
+        # The pending-requests counter must be decreased no matter which exception was raised.
         FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
-        return inference_response
 
 
 def retrieve_info_by_endpoint_id(end_point_id, in_end_point_name=None, in_model_name=None,
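
Note: the change above increments the pending-requests counter once at entry and then decrements it separately on every return path and in the except handler. The same invariant, exactly one decrement per increment, can also be centralized with try/finally. A minimal sketch, with an illustrative PendingCounter standing in for the real FEDML_MODEL_CACHE API:

import asyncio


class PendingCounter:
    """Illustrative stand-in for the pending-requests counter, not the FedML API."""

    def __init__(self):
        self.pending = 0

    def update_pending_requests_counter(self, increase=False, decrease=False):
        if increase:
            self.pending += 1
        if decrease:
            self.pending -= 1


CACHE = PendingCounter()


async def predict(input_json):
    CACHE.update_pending_requests_counter(increase=True)
    try:
        # Every early return and every exception passes through finally,
        # so no exit path can leak a pending count.
        if not input_json.get("token"):
            return {"error": True, "message": "token is not valid."}
        return {"result": "ok"}
    finally:
        CACHE.update_pending_requests_counter(decrease=True)


if __name__ == "__main__":
    print(asyncio.run(predict({"token": "abc"})), CACHE.pending)  # {'result': 'ok'} 0

With finally, the per-branch decrements in the error paths would become unnecessary; the commit instead keeps them explicit and adds the except handler as a catch-all.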
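The payload assembly inside the `if inference_output_url != "":` block is easy to miss: when the request body carries no explicit "inputs" field, the entire body is used as the input, and a top-level "stream" flag is merged in without overriding a flag already set inside the inputs. A self-contained illustration (build_input_list is a hypothetical helper, not part of the FedML API):

def build_input_list(input_json):
    # Mirrors the hunk above: fall back to the whole request body when there is
    # no explicit "inputs" field, then propagate the top-level "stream" flag
    # without overriding a flag already present inside the inputs.
    input_list = input_json.get("inputs", input_json)
    stream_flag = input_json.get("stream", False)
    input_list["stream"] = input_list.get("stream", stream_flag)
    return input_list


assert build_input_list({"inputs": {"text": "hi"}, "stream": True}) == {"text": "hi", "stream": True}
assert build_input_list({"text": "hi"}) == {"text": "hi", "stream": False}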