Skip to content

Commit 2478350

Browse files
authored
Merge pull request #2146 from FedML-AI/dev/v0.7.0
Dev/v0.7.0
2 parents 38e4453 + 85c3ad8 commit 2478350

19 files changed

Lines changed: 421 additions & 232 deletions

File tree

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
# FedML serving deployment config for the dummy GPU-occupier example.
workspace: "./src"
entry_point: "serve_main.py"
# Shell commands run on the worker before the entry point starts.
bootstrap: |
  echo "Bootstrap start..."
  sleep 5
  echo "Bootstrap finished"
auto_detect_public_ip: true
use_gpu: true

# Per-request timeout in seconds.
# NOTE(review): presumably consumed together with the "request_timeout_sec" /
# INFERENCE_REQUEST_TIMEOUT settings added elsewhere in this commit — confirm.
request_timeout_sec: 10
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
from fedml.serving import FedMLPredictor
from fedml.serving import FedMLInferenceRunner
import uuid
import torch

# Number of float32 elements that fill 1 GiB (4 bytes per element);
# integer division keeps it a whole element count.
num_elements = 1_073_741_824 // 4


class DummyPredictor(FedMLPredictor):
    """Dummy predictor that pins ~1 GiB of GPU memory to simulate a busy worker."""

    def __init__(self):
        super().__init__()
        # Create a tensor with these many elements and move it to the GPU.
        # BUG FIX: the original bound the GPU tensor to a local variable only,
        # so the allocation was kept alive merely by PyTorch's caching
        # allocator after __init__ returned (and a torch.cuda.empty_cache()
        # anywhere would release it). Hold a reference on the instance so the
        # memory stays occupied for the predictor's lifetime.
        tensor = torch.empty(num_elements, dtype=torch.float32)
        self.tensor_gpu = tensor.cuda()

        # for debug: leave a marker file so operators can tell the GPU is held
        with open("/tmp/dummy_gpu_occupier.txt", "w") as f:
            f.write("GPU is occupied")

        # Unique id so responses can be traced back to a specific worker.
        self.worker_id = uuid.uuid4()

    def predict(self, request):
        """Echo the request back, tagged with this worker's unique id."""
        return {f"AlohaV0From{self.worker_id}": request}


if __name__ == "__main__":
    predictor = DummyPredictor()
    fedml_inference_runner = FedMLInferenceRunner(predictor)
    fedml_inference_runner.run()

python/fedml/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
from copy import deepcopy
32

43
import multiprocess as multiprocessing
54
import os
@@ -9,7 +8,10 @@
98
import torch
109

1110
import fedml
11+
import dotenv
12+
1213
from .computing.scheduler.env.collect_env import collect_env
14+
from fedml.computing.scheduler.env import set_env_kv, load_env
1315
from .constants import (
1416
FEDML_BACKEND_SERVICE_URL_DEV,
1517
FEDML_BACKEND_SERVICE_URL_LOCAL,
@@ -449,11 +451,17 @@ def _run_distributed():
449451

450452

451453
def set_env_version(version):
    """Persist the FedML environment version (e.g. "release") to the shared
    .env file, then reload it so os.environ reflects the new value."""
    set_env_kv("FEDML_ENV_VERSION", version)
    load_env()
453456

454457

455458
def get_env_version():
    """Return the current FedML environment version, defaulting to "release".

    Side effect: when no version is set yet, the default is written back via
    set_env_version so later calls (and other processes reading the shared
    .env file) agree on it.
    """
    load_env()
    version = os.getenv("FEDML_ENV_VERSION")
    if version is None:
        version = "release"
        set_env_version(version)
    return version
457465

458466

459467
def _get_backend_service():
@@ -510,7 +518,7 @@ def get_local_on_premise_platform_port():
510518

511519
def _get_local_s3_like_service_url():
512520
return FEDML_S3_DOMAIN_LOCAL
513-
521+
514522

515523
from fedml import device
516524
from fedml import data

python/fedml/computing/scheduler/comm_utils/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ class SchedulerConstants:
7878
ENDPOINT_INFERENCE_READY_TIMEOUT = 15
7979
ENDPOINT_STATUS_CHECK_TIMEOUT = 60 * 3
8080

81-
MQTT_INFERENCE_TIMEOUT = 60 * 6
82-
8381
TRAIN_PROVISIONING_TIMEOUT = 60 * 25
8482
TRAIN_STARTING_TIMEOUT = 60 * 15
8583
TRAIN_STOPPING_TIMEOUT = 60 * 5
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .collect_env import load_env, set_env_kv

python/fedml/computing/scheduler/env/collect_env.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import traceback
33

44
import fedml
5+
import dotenv
56
from fedml.computing.scheduler.comm_utils.hardware_utils import HardwareUtil
67
from fedml.computing.scheduler.slave.client_diagnosis import ClientDiagnosis
8+
from ..slave.client_constants import ClientConstants
79

810

911
def collect_env():
@@ -108,3 +110,24 @@ def collect_env():
108110
except Exception as e:
109111
print(f"The connection exception: {traceback.format_exc()}")
110112
pass
113+
114+
115+
def get_env_file():
    """Return the path of the shared .env config file, creating it if needed.

    Returns:
        str: absolute path to ``<global_services_dir>/.env``.
    """
    global_services_dir = ClientConstants.get_global_services_dir()
    # Robustness: make sure the directory exists before touching the file —
    # otherwise the open() below raises FileNotFoundError on a fresh install.
    os.makedirs(global_services_dir, exist_ok=True)
    env_config_file = os.path.join(global_services_dir, ".env")
    # Create an empty file if it does not exist so dotenv has a target.
    if not os.path.exists(env_config_file):
        with open(env_config_file, 'w') as f:
            f.write("")
    return env_config_file
123+
124+
125+
def load_env():
    """Load key/value pairs from the shared .env file into os.environ,
    overriding any values already present in the process environment."""
    dotenv.load_dotenv(dotenv_path=get_env_file(), override=True)
128+
129+
130+
def set_env_kv(key, value):
    """Write ``key=value`` into the shared .env file and reload it so the
    change is immediately visible in this process's environment."""
    dotenv.set_key(get_env_file(), key, value)
    load_env()

python/fedml/computing/scheduler/model_scheduler/autoscaler/autoscaler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def scale_operation_endpoint(self,
339339

340340
# Fetch all metrics record from the database.
341341
metrics = self.fedml_model_cache.get_endpoint_metrics(
342-
endpoint_id=endpoint_id)
342+
end_point_id=endpoint_id)
343343

344344
# Default to nothing.
345345
scale_op = ScaleOp.NO_OP

python/fedml/computing/scheduler/model_scheduler/device_client_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class ClientConstants(object):
9595
INFERENCE_ENGINE_TYPE_INT_DEFAULT = 2
9696
INFERENCE_MODEL_VERSION = "1"
9797
INFERENCE_INFERENCE_SERVER_VERSION = "v2"
98+
INFERENCE_REQUEST_TIMEOUT = 30
9899

99100
MSG_MODELOPS_DEPLOYMENT_STATUS_INITIALIZING = "INITIALIZING"
100101
MSG_MODELOPS_DEPLOYMENT_STATUS_DEPLOYING = "DEPLOYING"

python/fedml/computing/scheduler/model_scheduler/device_http_inference_protocol.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
import traceback
2-
from typing import Mapping
3-
from urllib.parse import urlparse
4-
51
import httpx
2+
import traceback
63

74
from .device_client_constants import ClientConstants
8-
import requests
5+
96
from fastapi.responses import Response
107
from fastapi.responses import StreamingResponse
8+
from urllib.parse import urlparse
9+
from typing import Mapping
1110

1211

1312
class FedMLHttpInference:

python/fedml/computing/scheduler/model_scheduler/device_model_cache.py

Lines changed: 67 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class FedMLModelCache(Singleton):
3333

3434
FEDML_KEY_COUNT_PER_SCAN = 1000
3535

36+
FEDML_PENDING_REQUESTS_COUNTER = "FEDML_PENDING_REQUESTS_COUNTER"
37+
3638
def __init__(self):
3739
if not hasattr(self, "redis_pool"):
3840
self.redis_pool = None
@@ -110,7 +112,7 @@ def set_user_setting_replica_num(self, end_point_id,
110112
replica_num: int, enable_auto_scaling: bool = False,
111113
scale_min: int = 0, scale_max: int = 0, state: str = "UNKNOWN",
112114
target_queries_per_replica: int = 60, aggregation_window_size_seconds: int = 60,
113-
scale_down_delay_seconds: int = 120
115+
scale_down_delay_seconds: int = 120, timeout_s: int = 30
114116
) -> bool:
115117
"""
116118
Key: FEDML_MODEL_ENDPOINT_REPLICA_USER_SETTING_TAG--<end_point_id>
@@ -136,7 +138,8 @@ def set_user_setting_replica_num(self, end_point_id,
136138
"scale_min": scale_min, "scale_max": scale_max, "state": state,
137139
"target_queries_per_replica": target_queries_per_replica,
138140
"aggregation_window_size_seconds": aggregation_window_size_seconds,
139-
"scale_down_delay_seconds": scale_down_delay_seconds
141+
"scale_down_delay_seconds": scale_down_delay_seconds,
142+
"request_timeout_sec": timeout_s
140143
}
141144
try:
142145
self.redis_connection.set(self.get_user_setting_replica_num_key(end_point_id), json.dumps(replica_num_dict))
@@ -362,7 +365,7 @@ def get_idle_device(self, end_point_id, end_point_name,
362365
if "model_status" in result_payload and result_payload["model_status"] == "DEPLOYED":
363366
idle_device_list.append({"device_id": device_id, "end_point_id": end_point_id})
364367

365-
logging.info(f"{len(idle_device_list)} devices has this model on it: {idle_device_list}")
368+
logging.info(f"{len(idle_device_list)} devices this model has on it: {idle_device_list}")
366369

367370
if len(idle_device_list) <= 0:
368371
return None, None
@@ -824,38 +827,37 @@ def get_monitor_metrics_key(self, end_point_id, end_point_name, model_name, mode
824827
end_point_id, end_point_name, model_name, model_version)
825828

826829
def get_endpoint_metrics(self,
827-
endpoint_id,
830+
end_point_id,
828831
k_recent=None) -> List[Any]:
829832
model_deployment_monitor_metrics = list()
830833
try:
831834
key_pattern = "{}*{}*".format(
832835
self.FEDML_MODEL_DEPLOYMENT_MONITOR_TAG,
833-
endpoint_id)
834-
model_deployment_monitor_endpoint_keys = \
836+
end_point_id)
837+
model_deployment_monitor_endpoint_key = \
835838
self.redis_connection.keys(pattern=key_pattern)
836839
# Since the reply is a list, we need to make sure the list
837840
# is non-empty otherwise the index will raise an error.
838-
if model_deployment_monitor_endpoint_keys:
841+
if model_deployment_monitor_endpoint_key:
839842
model_deployment_monitor_endpoint_key = \
840-
model_deployment_monitor_endpoint_keys[0]
841-
else:
842-
raise Exception("Function `get_endpoint_metrics` Key {} does not exist."
843-
.format(key_pattern))
844-
# Set start and end index depending on the size of the
845-
# list and the requested number of most recent records.
846-
num_records = self.redis_connection.llen(name=model_deployment_monitor_endpoint_key)
847-
# if k_most_recent is None, then fetch all by default.
848-
start, end = 0, -1
849-
# if k_most_recent is positive then fetch [-k_most_recent:]
850-
if k_recent and k_recent > 0:
851-
start = num_records - k_recent
852-
model_deployment_monitor_metrics = \
853-
self.redis_connection.lrange(
854-
name=model_deployment_monitor_endpoint_key,
855-
start=start,
856-
end=end)
857-
model_deployment_monitor_metrics = [
858-
json.loads(m) for m in model_deployment_monitor_metrics]
843+
model_deployment_monitor_endpoint_key[0]
844+
845+
# Set start and end index depending on the size of the
846+
# list and the requested number of most recent records.
847+
num_records = self.redis_connection.llen(
848+
name=model_deployment_monitor_endpoint_key)
849+
# if k_most_recent is None, then fetch all by default.
850+
start, end = 0, -1
851+
# if k_most_recent is positive then fetch [-k_most_recent:]
852+
if k_recent and k_recent > 0:
853+
start = num_records - k_recent
854+
model_deployment_monitor_metrics = \
855+
self.redis_connection.lrange(
856+
name=model_deployment_monitor_endpoint_key,
857+
start=start,
858+
end=end)
859+
model_deployment_monitor_metrics = [
860+
json.loads(m) for m in model_deployment_monitor_metrics]
859861

860862
except Exception as e:
861863
logging.error(e)
@@ -868,24 +870,24 @@ def get_endpoint_replicas_results(self, endpoint_id) -> List[Any]:
868870
key_pattern = "{}*{}*".format(
869871
self.FEDML_MODEL_DEPLOYMENT_RESULT_TAG,
870872
endpoint_id)
871-
model_deployment_result_key = \
873+
model_deployment_result_keys = \
872874
self.redis_connection.keys(pattern=key_pattern)
873-
if model_deployment_result_key:
875+
if model_deployment_result_keys:
874876
model_deployment_result_key = \
875-
model_deployment_result_key[0]
877+
model_deployment_result_keys[0]
878+
replicas_results = \
879+
self.redis_connection.lrange(
880+
name=model_deployment_result_key,
881+
start=0,
882+
end=-1)
883+
# Format the result value to a properly formatted json.
884+
for replica_idx, replica in enumerate(replicas_results):
885+
replicas_results[replica_idx] = json.loads(replica)
886+
replicas_results[replica_idx]["result"] = \
887+
json.loads(replicas_results[replica_idx]["result"])
876888
else:
877889
raise Exception("Function `get_endpoint_replicas_results` Key {} does not exist."
878890
.format(key_pattern))
879-
replicas_results = \
880-
self.redis_connection.lrange(
881-
name=model_deployment_result_key,
882-
start=0,
883-
end=-1)
884-
885-
# Format the result value to a properly formatted json.
886-
for replica_idx, replica in enumerate(replicas_results):
887-
replicas_results[replica_idx] = json.loads(replica)
888-
replicas_results[replica_idx]["result"] = json.loads(replicas_results[replica_idx]["result"])
889891

890892
except Exception as e:
891893
logging.error(e)
@@ -898,11 +900,16 @@ def get_endpoint_settings(self, endpoint_id) -> Dict:
898900
key_pattern = "{}*{}*".format(
899901
self.FEDML_MODEL_ENDPOINT_REPLICA_USER_SETTING_TAG,
900902
endpoint_id)
901-
endpoint_settings = \
903+
904+
endpoint_settings_keys = \
902905
self.redis_connection.keys(pattern=key_pattern)
903-
if endpoint_settings:
906+
907+
if len(endpoint_settings_keys) > 0:
904908
endpoint_settings = \
905-
json.load(endpoint_settings[0])
909+
self.redis_connection.get(endpoint_settings_keys[0])
910+
911+
if not isinstance(endpoint_settings, dict):
912+
endpoint_settings = json.loads(endpoint_settings)
906913
else:
907914
raise Exception("Function `get_endpoint_settings` Key {} does not exist."
908915
.format(key_pattern))
@@ -966,3 +973,21 @@ def delete_endpoint_scaling_down_decision_time(self, end_point_id) -> bool:
966973
return bool(self.redis_connection.hdel(
967974
self.FEDML_MODEL_ENDPOINT_SCALING_DOWN_DECISION_TIME_TAG,
968975
end_point_id))
976+
977+
def get_pending_requests_counter(self) -> int:
    """Return the current number of pending (in-flight) inference requests.

    A missing Redis key is treated as 0. Unlike the previous version, this
    read path no longer writes to Redis: the exists()/set()/get() sequence
    was racy, and Redis INCR/DECR already treat a missing key as 0, so
    eagerly creating the key bought nothing.
    """
    value = self.redis_connection.get(self.FEDML_PENDING_REQUESTS_COUNTER)
    return int(value) if value is not None else 0
981+
982+
def update_pending_requests_counter(self, increase=False, decrease=False) -> int:
    """Adjust the pending-requests counter in Redis and return its new value.

    Args:
        increase: increment the counter by one.
        decrease: decrement the counter by one, clamped at zero.

    Returns:
        int: the counter value after the update.
    """
    if not self.redis_connection.exists(self.FEDML_PENDING_REQUESTS_COUNTER):
        self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
    if increase:
        self.redis_connection.incr(self.FEDML_PENDING_REQUESTS_COUNTER)
    if decrease:
        # Making sure the counter never becomes negative!
        # BUG FIX: the original guard was `< 0`, checked BEFORE the DECR, so
        # decrementing at 0 still drove the counter to -1. Use `<= 0` so the
        # value is clamped at zero as the comment promises.
        # NOTE(review): this check-then-decr is still not atomic across
        # processes — a Lua script or WATCH/MULTI would close that race.
        if self.get_pending_requests_counter() <= 0:
            self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
        else:
            self.redis_connection.decr(self.FEDML_PENDING_REQUESTS_COUNTER)
    return self.get_pending_requests_counter()

0 commit comments

Comments
 (0)