88from typing import Any , Mapping , MutableMapping , Union
99
1010from fastapi import FastAPI , Request , Response , status
11- from fastapi .responses import StreamingResponse
11+ from fastapi .responses import StreamingResponse , JSONResponse
1212
1313from fedml .computing .scheduler .model_scheduler .device_client_constants import ClientConstants
1414from fedml .computing .scheduler .model_scheduler .device_http_inference_protocol import FedMLHttpInference
@@ -60,7 +60,9 @@ async def auth_middleware(request: Request, call_next):
6060 # Attempt to parse the JSON body.
6161 request_json = await request .json ()
6262 except json .JSONDecodeError :
63- return Response ("Invalid JSON." , status_code = status .HTTP_400_BAD_REQUEST )
63+ return JSONResponse (
64+ {"error" : True , "message" : "Invalid JSON." },
65+ status_code = status .HTTP_400_BAD_REQUEST )
6466
6567 # Get total pending requests.
6668 pending_requests_num = FEDML_MODEL_CACHE .get_pending_requests_counter ()
@@ -84,7 +86,9 @@ async def auth_middleware(request: Request, call_next):
8486
8587 # If timeout threshold is exceeded then cancel and return time out error.
8688 if (mean_latency * pending_requests_num ) > request_timeout_s :
87- return Response ("Request timed out." , status_code = status .HTTP_504_GATEWAY_TIMEOUT )
89+ return JSONResponse (
90+ {"error" : True , "message" : "Request timed out." },
91+ status_code = status .HTTP_504_GATEWAY_TIMEOUT )
8892
8993 response = await call_next (request )
9094 return response
0 commit comments