
Commit 10c5e17

Merge pull request #2142 from FedML-AI/dimitris/fail_fast_policy_merge
Fast Fail and Timeout Enforcement Policy for Model Deploy Endpoints
2 parents 9a8f307 + b0a55ad commit 10c5e17

11 files changed

Lines changed: 308 additions & 162 deletions


Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+workspace: "./src"
+entry_point: "serve_main.py"
+bootstrap: |
+  echo "Bootstrap start..."
+  sleep 5
+  echo "Bootstrap finished"
+auto_detect_public_ip: true
+use_gpu: true
+
+request_timeout_sec: 10
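This example config caps each inference request at 10 seconds via the new `request_timeout_sec` field. As a rough sketch of the fail-fast idea only (not code from this commit; `call_with_timeout` and the error payload are invented for illustration), an asyncio-based gateway could enforce such a cap like this:

import asyncio

async def call_with_timeout(predict_coro, timeout_s: float = 10.0):
    # Fail fast: if the replica does not answer within timeout_s seconds,
    # return a timeout error instead of letting the request wait forever.
    try:
        return await asyncio.wait_for(predict_coro, timeout=timeout_s)
    except asyncio.TimeoutError:
        return {"error": True, "message": f"request timed out after {timeout_s}s"}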
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+from fedml.serving import FedMLPredictor
+from fedml.serving import FedMLInferenceRunner
+import uuid
+import torch
+
+# Calculate the number of elements
+num_elements = 1_073_741_824 // 4  # using integer division for whole elements
+
+
+class DummyPredictor(FedMLPredictor):
+    def __init__(self):
+        super().__init__()
+        # Create a tensor with this many elements
+        tensor = torch.empty(num_elements, dtype=torch.float32)
+
+        # Move the tensor to GPU
+        tensor_gpu = tensor.cuda()
+
+        # for debug
+        with open("/tmp/dummy_gpu_occupier.txt", "w") as f:
+            f.write("GPU is occupied")
+
+        self.worker_id = uuid.uuid4()
+
+    def predict(self, request):
+        return {f"AlohaV0From{self.worker_id}": request}
+
+
+if __name__ == "__main__":
+    predictor = DummyPredictor()
+    fedml_inference_runner = FedMLInferenceRunner(predictor)
+    fedml_inference_runner.run()
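The dummy predictor deliberately occupies GPU memory up front, and the allocation size works out exactly: 1_073_741_824 // 4 = 268_435_456 float32 elements, at 4 bytes each 1_073_741_824 bytes = 1 GiB. A quick informational check (not part of the commit):

num_elements = 1_073_741_824 // 4   # 268_435_456 float32 elements
assert num_elements * 4 == 2 ** 30  # exactly 1 GiB of GPU memory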

python/fedml/computing/scheduler/comm_utils/constants.py

Lines changed: 0 additions & 2 deletions
@@ -78,8 +78,6 @@ class SchedulerConstants:
     ENDPOINT_INFERENCE_READY_TIMEOUT = 15
     ENDPOINT_STATUS_CHECK_TIMEOUT = 60 * 3

-    MQTT_INFERENCE_TIMEOUT = 60 * 6
-
     TRAIN_PROVISIONING_TIMEOUT = 60 * 25
     TRAIN_STARTING_TIMEOUT = 60 * 15
     TRAIN_STOPPING_TIMEOUT = 60 * 5

python/fedml/computing/scheduler/model_scheduler/device_client_constants.py

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ class ClientConstants(object):
     INFERENCE_ENGINE_TYPE_INT_DEFAULT = 2
     INFERENCE_MODEL_VERSION = "1"
     INFERENCE_INFERENCE_SERVER_VERSION = "v2"
+    INFERENCE_REQUEST_TIMEOUT = 30

     MSG_MODELOPS_DEPLOYMENT_STATUS_INITIALIZING = "INITIALIZING"
     MSG_MODELOPS_DEPLOYMENT_STATUS_DEPLOYING = "DEPLOYING"
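The new INFERENCE_REQUEST_TIMEOUT gives endpoints that never set their own timeout a 30-second default. A minimal sketch of that lookup, assuming the settings are read back as a dict (the helper name resolve_request_timeout is hypothetical; the "request_timeout_sec" key matches the one persisted in device_model_cache.py below):

from fedml.computing.scheduler.model_scheduler.device_client_constants import ClientConstants

def resolve_request_timeout(endpoint_settings: dict) -> int:
    # Hypothetical helper: prefer the per-endpoint setting, otherwise
    # fall back to the class-level 30-second default.
    return endpoint_settings.get(
        "request_timeout_sec", ClientConstants.INFERENCE_REQUEST_TIMEOUT)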

python/fedml/computing/scheduler/model_scheduler/device_http_inference_protocol.py

Lines changed: 4 additions & 5 deletions
@@ -1,13 +1,12 @@
-import traceback
-from typing import Mapping
-from urllib.parse import urlparse
-
 import httpx
+import traceback

 from .device_client_constants import ClientConstants
-import requests
+
 from fastapi.responses import Response
 from fastapi.responses import StreamingResponse
+from urllib.parse import urlparse
+from typing import Mapping


 class FedMLHttpInference:
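With requests removed, httpx is the module's only HTTP client, and it applies timeouts to async calls directly. A rough illustration of the pattern (not taken from this file; probe_inference_url is invented for the sketch):

import httpx

async def probe_inference_url(url: str, timeout_s: float) -> bool:
    # httpx raises client-side once timeout_s elapses, so a hung replica
    # becomes a fast, catchable failure instead of a stuck request.
    try:
        async with httpx.AsyncClient(timeout=timeout_s) as client:
            response = await client.get(url)
        return response.status_code == 200
    except httpx.HTTPError:
        return False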

python/fedml/computing/scheduler/model_scheduler/device_model_cache.py

Lines changed: 67 additions & 42 deletions
@@ -33,6 +33,8 @@ class FedMLModelCache(Singleton):
     FEDML_KEY_COUNT_PER_SCAN = 1000

+    FEDML_PENDING_REQUESTS_COUNTER = "FEDML_PENDING_REQUESTS_COUNTER"
+
     def __init__(self):
         if not hasattr(self, "redis_pool"):
             self.redis_pool = None
@@ -110,7 +112,7 @@ def set_user_setting_replica_num(self, end_point_id,
                                      replica_num: int, enable_auto_scaling: bool = False,
                                      scale_min: int = 0, scale_max: int = 0, state: str = "UNKNOWN",
                                      target_queries_per_replica: int = 60, aggregation_window_size_seconds: int = 60,
-                                     scale_down_delay_seconds: int = 120
+                                     scale_down_delay_seconds: int = 120, timeout_s: int = 30
                                      ) -> bool:
         """
         Key: FEDML_MODEL_ENDPOINT_REPLICA_USER_SETTING_TAG--<end_point_id>
@@ -136,7 +138,8 @@ def set_user_setting_replica_num(self, end_point_id,
             "scale_min": scale_min, "scale_max": scale_max, "state": state,
             "target_queries_per_replica": target_queries_per_replica,
             "aggregation_window_size_seconds": aggregation_window_size_seconds,
-            "scale_down_delay_seconds": scale_down_delay_seconds
+            "scale_down_delay_seconds": scale_down_delay_seconds,
+            "request_timeout_sec": timeout_s
         }
         try:
             self.redis_connection.set(self.get_user_setting_replica_num_key(end_point_id), json.dumps(replica_num_dict))
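The stored settings blob therefore gains a "request_timeout_sec" field alongside the autoscaling knobs. Its rough shape, with illustrative values (only the keys visible in this hunk are shown; fields truncated out of the diff are omitted):

import json

replica_num_dict = {
    # ... fields not visible in this hunk are omitted ...
    "scale_min": 1, "scale_max": 4, "state": "DEPLOYED",
    "target_queries_per_replica": 60,
    "aggregation_window_size_seconds": 60,
    "scale_down_delay_seconds": 120,
    "request_timeout_sec": 10,
}
print(json.dumps(replica_num_dict))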
@@ -362,7 +365,7 @@ def get_idle_device(self, end_point_id, end_point_name,
             if "model_status" in result_payload and result_payload["model_status"] == "DEPLOYED":
                 idle_device_list.append({"device_id": device_id, "end_point_id": end_point_id})

-        logging.info(f"{len(idle_device_list)} devices has this model on it: {idle_device_list}")
+        logging.info(f"{len(idle_device_list)} devices this model has on it: {idle_device_list}")

         if len(idle_device_list) <= 0:
             return None, None
@@ -824,38 +827,37 @@ def get_monitor_metrics_key(self, end_point_id, end_point_name, model_name, model_version):
             end_point_id, end_point_name, model_name, model_version)

     def get_endpoint_metrics(self,
-                             endpoint_id,
+                             end_point_id,
                              k_recent=None) -> List[Any]:
         model_deployment_monitor_metrics = list()
         try:
             key_pattern = "{}*{}*".format(
                 self.FEDML_MODEL_DEPLOYMENT_MONITOR_TAG,
-                endpoint_id)
-            model_deployment_monitor_endpoint_keys = \
+                end_point_id)
+            model_deployment_monitor_endpoint_key = \
                 self.redis_connection.keys(pattern=key_pattern)
             # Since the reply is a list, we need to make sure the list
             # is non-empty otherwise the index will raise an error.
-            if model_deployment_monitor_endpoint_keys:
+            if model_deployment_monitor_endpoint_key:
                 model_deployment_monitor_endpoint_key = \
-                    model_deployment_monitor_endpoint_keys[0]
-            else:
-                raise Exception("Function `get_endpoint_metrics` Key {} does not exist."
-                                .format(key_pattern))
-            # Set start and end index depending on the size of the
-            # list and the requested number of most recent records.
-            num_records = self.redis_connection.llen(name=model_deployment_monitor_endpoint_key)
-            # if k_most_recent is None, then fetch all by default.
-            start, end = 0, -1
-            # if k_most_recent is positive then fetch [-k_most_recent:]
-            if k_recent and k_recent > 0:
-                start = num_records - k_recent
-            model_deployment_monitor_metrics = \
-                self.redis_connection.lrange(
-                    name=model_deployment_monitor_endpoint_key,
-                    start=start,
-                    end=end)
-            model_deployment_monitor_metrics = [
-                json.loads(m) for m in model_deployment_monitor_metrics]
+                    model_deployment_monitor_endpoint_key[0]
+
+                # Set start and end index depending on the size of the
+                # list and the requested number of most recent records.
+                num_records = self.redis_connection.llen(
+                    name=model_deployment_monitor_endpoint_key)
+                # if k_most_recent is None, then fetch all by default.
+                start, end = 0, -1
+                # if k_most_recent is positive then fetch [-k_most_recent:]
+                if k_recent and k_recent > 0:
+                    start = num_records - k_recent
+                model_deployment_monitor_metrics = \
+                    self.redis_connection.lrange(
+                        name=model_deployment_monitor_endpoint_key,
+                        start=start,
+                        end=end)
+                model_deployment_monitor_metrics = [
+                    json.loads(m) for m in model_deployment_monitor_metrics]

         except Exception as e:
             logging.error(e)
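By default the method fetches the whole list (start=0, end=-1); with a positive k_recent it fetches only the last k_recent entries. A standalone redis-py illustration of the same window arithmetic (assumes a local Redis server; the key name is made up):

import redis

r = redis.Redis()
r.delete("demo_metrics")
for i in range(10):
    r.rpush("demo_metrics", i)

# Same windowing as get_endpoint_metrics: start = len - k_recent, end = -1
k_recent = 3
start = r.llen("demo_metrics") - k_recent   # 10 - 3 = 7
print(r.lrange("demo_metrics", start, -1))  # [b'7', b'8', b'9']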
@@ -868,24 +870,24 @@ def get_endpoint_replicas_results(self, endpoint_id) -> List[Any]:
             key_pattern = "{}*{}*".format(
                 self.FEDML_MODEL_DEPLOYMENT_RESULT_TAG,
                 endpoint_id)
-            model_deployment_result_key = \
+            model_deployment_result_keys = \
                 self.redis_connection.keys(pattern=key_pattern)
-            if model_deployment_result_key:
+            if model_deployment_result_keys:
                 model_deployment_result_key = \
-                    model_deployment_result_key[0]
+                    model_deployment_result_keys[0]
+                replicas_results = \
+                    self.redis_connection.lrange(
+                        name=model_deployment_result_key,
+                        start=0,
+                        end=-1)
+                # Format the result value to a properly formatted json.
+                for replica_idx, replica in enumerate(replicas_results):
+                    replicas_results[replica_idx] = json.loads(replica)
+                    replicas_results[replica_idx]["result"] = \
+                        json.loads(replicas_results[replica_idx]["result"])
             else:
                 raise Exception("Function `get_endpoint_replicas_results` Key {} does not exist."
                                 .format(key_pattern))
-            replicas_results = \
-                self.redis_connection.lrange(
-                    name=model_deployment_result_key,
-                    start=0,
-                    end=-1)
-
-            # Format the result value to a properly formatted json.
-            for replica_idx, replica in enumerate(replicas_results):
-                replicas_results[replica_idx] = json.loads(replica)
-                replicas_results[replica_idx]["result"] = json.loads(replicas_results[replica_idx]["result"])

         except Exception as e:
             logging.error(e)
@@ -898,11 +900,16 @@ def get_endpoint_settings(self, endpoint_id) -> Dict:
             key_pattern = "{}*{}*".format(
                 self.FEDML_MODEL_ENDPOINT_REPLICA_USER_SETTING_TAG,
                 endpoint_id)
-            endpoint_settings = \
+
+            endpoint_settings_keys = \
                 self.redis_connection.keys(pattern=key_pattern)
-            if endpoint_settings:
+
+            if len(endpoint_settings_keys) > 0:
                 endpoint_settings = \
-                    json.load(endpoint_settings[0])
+                    self.redis_connection.get(endpoint_settings_keys[0])
+
+                if not isinstance(endpoint_settings, dict):
+                    endpoint_settings = json.loads(endpoint_settings)
             else:
                 raise Exception("Function `get_endpoint_settings` Key {} does not exist."
                                 .format(key_pattern))
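The bug this hunk fixes: the old code never fetched the value at all and passed the Redis key itself to json.load, which expects a file-like object and raises. json.loads is the string decoder (sample payload below is illustrative):

import json

raw = '{"request_timeout_sec": 10, "scale_min": 1}'  # what a redis GET returns
settings = json.loads(raw)  # correct: decode a JSON string
print(settings["request_timeout_sec"])  # 10
# json.load(raw) would raise AttributeError: no .read() on a str/bytes key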
@@ -966,3 +973,21 @@ def delete_endpoint_scaling_down_decision_time(self, end_point_id) -> bool:
         return bool(self.redis_connection.hdel(
             self.FEDML_MODEL_ENDPOINT_SCALING_DOWN_DECISION_TIME_TAG,
             end_point_id))
+
+    def get_pending_requests_counter(self) -> int:
+        if not self.redis_connection.exists(self.FEDML_PENDING_REQUESTS_COUNTER):
+            self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
+        return int(self.redis_connection.get(self.FEDML_PENDING_REQUESTS_COUNTER))
+
+    def update_pending_requests_counter(self, increase=False, decrease=False) -> int:
+        if not self.redis_connection.exists(self.FEDML_PENDING_REQUESTS_COUNTER):
+            self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
+        if increase:
+            self.redis_connection.incr(self.FEDML_PENDING_REQUESTS_COUNTER)
+        if decrease:
+            # Making sure the counter never becomes negative!
+            if self.get_pending_requests_counter() < 0:
+                self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
+            else:
+                self.redis_connection.decr(self.FEDML_PENDING_REQUESTS_COUNTER)
+        return self.get_pending_requests_counter()
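Together these two methods give the serving layer a shared, Redis-backed count of in-flight requests, which is exactly what a fast-fail admission check needs. A hypothetical use (MAX_PENDING, admit_request, and finish_request are invented for this sketch, not part of the commit):

MAX_PENDING = 100  # illustrative limit, not from the commit

def admit_request(cache) -> bool:
    # Fast fail: refuse new work when too many requests are already pending.
    if cache.get_pending_requests_counter() >= MAX_PENDING:
        return False  # caller should answer immediately, e.g. HTTP 503
    cache.update_pending_requests_counter(increase=True)
    return True

def finish_request(cache) -> None:
    # Always decrement on the way out, whether the request succeeded or not.
    cache.update_pending_requests_counter(decrease=True)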
