Skip to content

Commit e667ded

Browse files
authored
Merge pull request #2151 from FedML-AI/dimitris/fix_pending_requests_counter
Adding a Redis hash for counting the number of pending requests per endpoint.
2 parents 6b33065 + c29cf1d commit e667ded

3 files changed

Lines changed: 30 additions & 25 deletions

File tree

python/fedml/computing/scheduler/model_scheduler/device_model_cache.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def set_user_setting_replica_num(self, end_point_id,
139139
"target_queries_per_replica": target_queries_per_replica,
140140
"aggregation_window_size_seconds": aggregation_window_size_seconds,
141141
"scale_down_delay_seconds": scale_down_delay_seconds,
142-
"request_timeout_sec": timeout_s
142+
ServerConstants.INFERENCE_REQUEST_TIMEOUT_KEY: timeout_s
143143
}
144144
try:
145145
self.redis_connection.set(self.get_user_setting_replica_num_key(end_point_id), json.dumps(replica_num_dict))
@@ -974,20 +974,21 @@ def delete_endpoint_scaling_down_decision_time(self, end_point_id) -> bool:
974974
self.FEDML_MODEL_ENDPOINT_SCALING_DOWN_DECISION_TIME_TAG,
975975
end_point_id))
976976

977-
def get_pending_requests_counter(self) -> int:
978-
if not self.redis_connection.exists(self.FEDML_PENDING_REQUESTS_COUNTER):
979-
self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
980-
return int(self.redis_connection.get(self.FEDML_PENDING_REQUESTS_COUNTER))
977+
def get_pending_requests_counter(self, end_point_id) -> int:
978+
# If the endpoint does not exist inside the Hash collection, its counter is treated as 0.
979+
if self.redis_connection.hexists(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id):
980+
return int(self.redis_connection.hget(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id))
981+
return 0
981982

982-
def update_pending_requests_counter(self, increase=False, decrease=False) -> int:
983-
if not self.redis_connection.exists(self.FEDML_PENDING_REQUESTS_COUNTER):
984-
self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
983+
def update_pending_requests_counter(self, end_point_id, increase=False, decrease=False) -> int:
984+
if not self.redis_connection.hexists(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id):
985+
self.redis_connection.hset(self.FEDML_PENDING_REQUESTS_COUNTER, mapping={end_point_id: 0})
985986
if increase:
986-
self.redis_connection.incr(self.FEDML_PENDING_REQUESTS_COUNTER)
987+
self.redis_connection.hincrby(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id, 1)
987988
if decrease:
989+
# Careful with the negative delta: Redis has no native hash-decrement command, so use HINCRBY with -1.
990+
self.redis_connection.hincrby(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id, -1)
988991
# Making sure the counter never becomes negative!
989-
if self.get_pending_requests_counter() < 0:
990-
self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
991-
else:
992-
self.redis_connection.decr(self.FEDML_PENDING_REQUESTS_COUNTER)
993-
return self.get_pending_requests_counter()
992+
if self.get_pending_requests_counter(end_point_id) < 0:
993+
self.redis_connection.hset(self.FEDML_PENDING_REQUESTS_COUNTER, mapping={end_point_id: 0})
994+
return self.get_pending_requests_counter(end_point_id)

python/fedml/computing/scheduler/model_scheduler/device_model_inference.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,18 @@ async def auth_middleware(request: Request, call_next):
5555
{"error": True, "message": "Invalid JSON."},
5656
status_code=status.HTTP_400_BAD_REQUEST)
5757

58-
# Get total pending requests.
59-
pending_requests_num = FEDML_MODEL_CACHE.get_pending_requests_counter()
58+
# Get endpoint's total pending requests.
59+
end_point_id = request_json.get("end_point_id", None)
60+
pending_requests_num = FEDML_MODEL_CACHE.get_pending_requests_counter(end_point_id)
6061
if pending_requests_num:
61-
end_point_id = request_json.get("end_point_id", None)
6262
# Fetch metrics of the past k=3 requests.
6363
pask_k_metrics = FEDML_MODEL_CACHE.get_endpoint_metrics(
6464
end_point_id=end_point_id,
6565
k_recent=3)
6666

6767
# Get the request timeout from the endpoint settings.
6868
request_timeout_s = FEDML_MODEL_CACHE.get_endpoint_settings(end_point_id) \
69-
.get("request_timeout_s", ClientConstants.INFERENCE_REQUEST_TIMEOUT)
69+
.get(ServerConstants.INFERENCE_REQUEST_TIMEOUT_KEY, ServerConstants.INFERENCE_REQUEST_TIMEOUT_DEFAULT)
7070

7171
# Only proceed if the past k metrics collection is not empty.
7272
if pask_k_metrics:
@@ -76,7 +76,8 @@ async def auth_middleware(request: Request, call_next):
7676
mean_latency = sum(past_k_latencies_sec) / len(past_k_latencies_sec)
7777

7878
# If timeout threshold is exceeded then cancel and return time out error.
79-
if (mean_latency * pending_requests_num) > request_timeout_s:
79+
should_block = (mean_latency * pending_requests_num) > request_timeout_s
80+
if should_block:
8081
return JSONResponse(
8182
{"error": True, "message": "Request timed out."},
8283
status_code=status.HTTP_504_GATEWAY_TIMEOUT)
@@ -173,7 +174,7 @@ async def _predict(
173174
header=None
174175
) -> Union[MutableMapping[str, Any], Response, StreamingResponse]:
175176
# Always increase the pending requests counter on a new incoming request.
176-
FEDML_MODEL_CACHE.update_pending_requests_counter(increase=True)
177+
FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, increase=True)
177178
inference_response = {}
178179

179180
try:
@@ -205,14 +206,14 @@ async def _predict(
205206
if not is_endpoint_activated(in_end_point_id):
206207
inference_response = {"error": True, "message": "endpoint is not activated."}
207208
logging_inference_request(input_json, inference_response)
208-
FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
209+
FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
209210
return inference_response
210211

211212
# Found idle inference device
212213
idle_device, end_point_id, model_id, model_name, model_version, inference_host, inference_output_url = \
213214
found_idle_inference_device(in_end_point_id, in_end_point_name, in_model_name, in_model_version)
214215
if idle_device is None or idle_device == "":
215-
FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
216+
FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
216217
return {"error": True, "error_code": status.HTTP_404_NOT_FOUND,
217218
"message": "can not found active inference worker for this endpoint."}
218219

@@ -252,18 +253,18 @@ async def _predict(
252253
pass
253254

254255
logging_inference_request(input_json, inference_response)
255-
FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
256+
FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
256257
return inference_response
257258
else:
258259
inference_response = {"error": True, "message": "token is not valid."}
259260
logging_inference_request(input_json, inference_response)
260-
FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
261+
FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
261262
return inference_response
262263

263264
except Exception as e:
264265
logging.error("Inference Exception: {}".format(traceback.format_exc()))
265266
# Need to reduce the pending requests counter in whatever exception that may be raised.
266-
FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
267+
FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
267268

268269

269270
def retrieve_info_by_endpoint_id(end_point_id, in_end_point_name=None, in_model_name=None,

python/fedml/computing/scheduler/model_scheduler/device_server_constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ class ServerConstants(object):
104104
AUTO_DETECT_PUBLIC_IP = "auto_detect_public_ip"
105105
MODEL_INFERENCE_DEFAULT_PORT = 2203
106106
MODEL_CACHE_KEY_EXPIRE_TIME = 1 * 10
107+
108+
INFERENCE_REQUEST_TIMEOUT_KEY = "request_timeout_sec"
109+
INFERENCE_REQUEST_TIMEOUT_DEFAULT = 30
107110
# -----End-----
108111

109112
MODEL_DEPLOYMENT_STAGE1 = {"index": 1, "text": "ReceivedRequest"}

0 commit comments

Comments
 (0)