-
Notifications
You must be signed in to change notification settings - Fork 690
Expand file tree
/
Copy pathresponse_middleware.py
More file actions
117 lines (98 loc) · 4.33 KB
/
response_middleware.py
File metadata and controls
117 lines (98 loc) · 4.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
from typing import Optional
from starlette.exceptions import HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from common.core.config import settings
from common.utils.utils import SQLBotLogUtil
class ResponseMiddleware(BaseHTTPMiddleware):
    """Post-process HTTP responses produced by the wrapped application.

    * A 200 response whose ``content-type`` is exactly ``application/json`` is
      buffered, decoded and re-emitted inside a ``{"code", "data", "msg"}``
      envelope (unless the payload already has those three keys).
    * Responses whose content type looks static (HTML, JavaScript,
      TypeScript, CSS) get a ``Content-Security-Policy: frame-ancestors ...``
      header built from ``allow_origins``.
    * A fixed set of passthrough routes (MCP endpoints, OpenAPI/docs pages)
      and every non-200 response are returned untouched.
    """

    # Registry of every constructed instance so callers elsewhere can push new
    # allowed origins into all live middleware objects (see
    # ``update_allow_origins``).
    instances = []

    def __init__(self, app, allow_origins: Optional[list[str]] = None):
        """Create the middleware.

        Args:
            app: The wrapped ASGI application.
            allow_origins: Sources for the CSP ``frame-ancestors`` directive.
                Falls back to ``["'self'"]`` when ``None`` or empty.
        """
        super().__init__(app)
        self.allow_origins = allow_origins or ["'self'"]
        ResponseMiddleware.instances.append(self)

    def update_allow_origins(self, new_allow_origins: Optional[list[str]] = None):
        """Merge ``new_allow_origins`` into the allowed frame ancestors.

        Duplicates are removed while keeping first-seen order, so the emitted
        CSP header is stable between requests and restarts. A ``None`` or
        empty argument is a no-op.
        """
        if not new_allow_origins:
            return
        # dict.fromkeys dedupes like set() but preserves insertion order.
        self.allow_origins = list(dict.fromkeys(self.allow_origins + new_allow_origins))

    @staticmethod
    def _rebuild_json_response(source, content, status_code: int) -> JSONResponse:
        """Build a ``JSONResponse`` for ``content`` that carries ``source``'s headers.

        ``content-length`` and ``content-type`` are skipped because the body
        is re-serialized; ``JSONResponse`` computes fresh values for both.
        Headers are appended one at a time instead of being passed as a dict
        so that repeated headers (e.g. several ``Set-Cookie`` lines) are all
        preserved rather than collapsed to the last one.
        """
        rebuilt = JSONResponse(content=content, status_code=status_code)
        for key, value in source.headers.items():
            if key.lower() not in ("content-length", "content-type"):
                rebuilt.headers.append(key, value)
        return rebuilt

    async def dispatch(self, request, call_next):
        """Run the downstream app and post-process its response.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The original response, a re-wrapped ``JSONResponse``, or a 500
            ``JSONResponse`` whose body is the error text if the JSON body
            could not be processed.
        """
        response = await call_next(request)

        direct_paths = [
            f"{settings.API_V1_STR}/mcp/mcp_question",
            f"{settings.API_V1_STR}/mcp/mcp_assistant",
            "/openapi.json",
            "/docs",
            "/redoc",
        ]
        route = request.scope.get("route")
        # Declared route template, e.g. '/items/{item_id}'. Not every route
        # type is guaranteed to define ``path_format``, so fall back to ''.
        path_pattern = getattr(route, "path_format", "") if route else ""
        if (
            isinstance(response, JSONResponse)
            or request.url.path == f"{settings.API_V1_STR}/openapi.json"
            or path_pattern in direct_paths
        ):
            return response

        if response.status_code != 200:
            return response

        if response.headers.get("content-type") == "application/json":
            try:
                # Gather chunks and join once; ``bytes +=`` in a loop copies
                # the whole accumulated buffer on every chunk.
                chunks = [chunk async for chunk in response.body_iterator]
                raw_data = json.loads(b"".join(chunks).decode())
                already_wrapped = isinstance(raw_data, dict) and all(
                    k in raw_data for k in ["code", "data", "msg"]
                )
                if not already_wrapped:
                    raw_data = {
                        "code": 0,
                        "data": raw_data,
                        "msg": None,
                    }
                return self._rebuild_json_response(response, raw_data, response.status_code)
            except Exception as e:
                # The body iterator has already been consumed at this point,
                # so the original response cannot be returned; report the
                # failure instead.
                SQLBotLogUtil.error(f"Response processing error: {str(e)}", exc_info=True)
                return self._rebuild_json_response(response, str(e), 500)

        content_type = response.headers.get("content-type", "")
        static_content_types = ["text/html", "javascript", "typescript", "css"]
        if any(ct in content_type for ct in static_content_types) and self.allow_origins:
            frame_ancestors_value = " ".join(self.allow_origins)
            response.headers["Content-Security-Policy"] = f"frame-ancestors {frame_ancestors_value};"
        return response
class exception_handler():
    """Application-wide exception handlers that return JSON error responses.

    Both handlers log the failure through ``SQLBotLogUtil.error`` and set
    ``Access-Control-Allow-Origin: *`` on the error response.
    """

    @staticmethod
    async def http_exception_handler(request: Request, exc: HTTPException):
        """Convert an ``HTTPException`` into a ``JSONResponse``.

        The response uses ``exc.status_code`` and has ``exc.detail`` as its
        body. Any headers attached to the exception (for example
        ``WWW-Authenticate`` on a 401 or ``Retry-After`` on a 429) are
        forwarded, and ``Access-Control-Allow-Origin`` is always ``*``.
        """
        SQLBotLogUtil.error(f"HTTP Exception: {exc.detail}", exc_info=True)
        # Forward headers carried by the exception, dropping any CORS origin
        # header so the forced "*" below is the only one emitted.
        exc_headers = getattr(exc, "headers", None) or {}
        headers = {
            k: v for k, v in exc_headers.items()
            if k.lower() != "access-control-allow-origin"
        }
        headers["Access-Control-Allow-Origin"] = "*"
        return JSONResponse(
            status_code=exc.status_code,
            content=exc.detail,
            headers=headers
        )

    @staticmethod
    async def global_exception_handler(request: Request, exc: Exception):
        """Convert any unhandled exception into a 500 ``JSONResponse``.

        The body is ``str(exc)``.
        """
        SQLBotLogUtil.error(f"Unhandled Exception: {str(exc)}", exc_info=True)
        # NOTE(review): returning str(exc) exposes internal error text to
        # clients; consider a generic message in production.
        return JSONResponse(
            status_code=500,
            content=str(exc),
            headers={"Access-Control-Allow-Origin": "*"}
        )