-
Notifications
You must be signed in to change notification settings - Fork 690
Expand file tree
/
Copy pathassistant.py
More file actions
283 lines (249 loc) · 11.6 KB
/
assistant.py
File metadata and controls
283 lines (249 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import json
import re
import urllib
from typing import Optional
import requests
from fastapi import FastAPI
from sqlmodel import Session, select
from starlette.middleware.cors import CORSMiddleware
from apps.datasource.models.datasource import CoreDatasource
from apps.datasource.utils.utils import aes_encrypt
from apps.system.models.system_model import AssistantModel
from apps.system.schemas.auth import CacheName, CacheNamespace
from apps.system.schemas.system_schema import AssistantHeader, AssistantOutDsSchema, UserInfoDTO
from common.core.config import settings
from common.core.db import engine
from common.core.sqlbot_cache import cache
from common.utils.aes_crypto import simple_aes_decrypt
from common.utils.utils import SQLBotLogUtil, get_domain_list, string_to_numeric_hash
from common.core.deps import Trans
from common.core.response_middleware import ResponseMiddleware
@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
    """Look up one assistant record by primary key.

    The call is wrapped by ``@cache`` with ``assistant_id`` as the key
    expression.

    Args:
        session: Database session used for the lookup.
        assistant_id: Primary key of the ``AssistantModel`` row.

    Returns:
        Whatever ``session.get`` yields for that key (annotated as
        ``AssistantModel | None``).
    """
    return session.get(AssistantModel, assistant_id)
def get_assistant_user(*, id: int):
    """Build the ``UserInfoDTO`` that represents the built-in assistant user.

    Args:
        id: Value placed in the DTO's ``id`` field (keyword-only).

    Returns:
        A ``UserInfoDTO`` with fixed account/name/email values and ``oid=1``.
    """
    # Account and display name share the same fixed label.
    inner_label = "sqlbot-inner-assistant"
    return UserInfoDTO(
        id=id,
        account=inner_label,
        oid=1,
        name=inner_label,
        email="sqlbot-inner-assistant@sqlbot.com",
    )
def get_assistant_ds(session: Session, llm_service) -> list[dict]:
    """Return the simplified datasource list available to the current assistant.

    For assistants of type 0 or 2 that carry a configuration, datasources are
    read from the local ``CoreDatasource`` table, filtered by the configured
    ``oid``. When the assistant is not online the result is further restricted
    to the ids in the configuration's ``public_list``; an empty or missing
    ``public_list`` yields an empty result.

    In every other case the list comes from an ``AssistantOutDs`` instance,
    which is also stored on ``llm_service.out_ds_instance``.

    Args:
        session: Database session used to run the datasource query.
        llm_service: Object exposing ``current_assistant``; its
            ``out_ds_instance`` attribute is set on the external path.

    Returns:
        A list of ``{"id", "name", "description"}`` dictionaries.
    """
    assistant: AssistantHeader = llm_service.current_assistant

    if assistant.type in (0, 2) and assistant.configuration:
        settings_map: dict = json.loads(assistant.configuration)
        oid_value = int(settings_map['oid'])
        query = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
            CoreDatasource.oid == oid_value)

        if not assistant.online:
            allowed_ids = settings_map.get('public_list')
            if not allowed_ids:
                return []
            query = query.where(CoreDatasource.id.in_(allowed_ids))
            # NOTE: excluding ids from the configuration's 'private_list' was
            # present here only as disabled code and is intentionally not applied.

        return [
            {"id": row.id, "name": row.name, "description": row.description}
            for row in session.exec(query)
        ]

    # External path: the datasource list is provided by AssistantOutDs.
    external_ds: AssistantOutDs = AssistantOutDsFactory.get_instance(assistant)
    llm_service.out_ds_instance = external_ds
    return external_ds.get_simple_ds_list()
def init_dynamic_cors(app: FastAPI):
    """Refresh the allowed CORS origins from the domains stored on assistants.

    Reads every ``AssistantModel`` ordered by ``create_time``, expands each
    non-empty ``domain`` value through ``get_domain_list``, and merges the
    stripped, non-empty results with ``settings.all_cors_origins``. The merged
    list is written to the ``allow_origins`` kwarg of the registered
    ``CORSMiddleware`` entry (if any) and, when a ``ResponseMiddleware`` entry
    is registered, pushed to every object in ``ResponseMiddleware.instances``.

    Args:
        app: Application whose ``user_middleware`` list is inspected.

    Returns:
        ``None`` when the refresh completes; ``(False, exc)`` when any step
        raised. The failure is logged before being returned.
    """
    try:
        with Session(engine) as session:
            assistants = session.exec(select(AssistantModel).order_by(AssistantModel.create_time)).all()

            stripped_domains = (
                domain.strip()
                for item in assistants
                if item.domain
                for domain in get_domain_list(item.domain)
            )
            assistant_domains = [domain for domain in stripped_domains if domain]

            # dict.fromkeys removes duplicates while keeping first-seen order,
            # so the resulting origin list is deterministic across runs.
            updated_origins = list(dict.fromkeys(settings.all_cors_origins + assistant_domains))

            # First registered entry of each middleware class, if present.
            cors_middleware = next(
                (entry for entry in app.user_middleware if entry.cls == CORSMiddleware), None)
            response_middleware = next(
                (entry for entry in app.user_middleware if entry.cls == ResponseMiddleware), None)

            if cors_middleware:
                cors_middleware.kwargs['allow_origins'] = updated_origins
            if response_middleware:
                for instance in ResponseMiddleware.instances:
                    instance.update_allow_origins(updated_origins)
    except Exception as e:
        # Best-effort refresh: keep returning the error to the caller, but do
        # not let it disappear silently.
        SQLBotLogUtil.error(f"Failed to refresh dynamic CORS origins: {e}")
        return False, e
class AssistantOutDs:
    """Datasource list obtained from an assistant's external HTTP endpoint.

    Constructing an instance immediately calls ``get_ds_from_api``, which
    performs an HTTP GET against the endpoint named in the assistant's
    configuration and stores the converted result in ``ds_list``.
    """
    # Assistant whose configuration, domain and credentials drive the request.
    assistant: AssistantHeader
    # Converted datasource list; None until get_ds_from_api succeeds.
    ds_list: Optional[list[AssistantOutDsSchema]] = None
    # JSON text describing header/cookie/param entries sent with the request.
    certificate: Optional[str] = None
    # Origin used as base URL when the assistant lists several domains.
    request_origin: Optional[str] = None

    def __init__(self, assistant: AssistantHeader):
        """Store the assistant's settings and load its datasource list.

        Raises:
            Exception: Propagated from ``get_ds_from_api``.
        """
        self.assistant = assistant
        self.ds_list = None
        self.certificate = assistant.certificate
        self.request_origin = assistant.request_origin
        self.get_ds_from_api()

    # @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
    def get_ds_from_api(self):
        """Fetch the datasource list from the configured endpoint.

        The assistant configuration (JSON) supplies ``endpoint`` and an
        optional ``timeout`` (default 10). Each certificate entry is routed by
        its ``target`` ('header', 'cookie' or 'param') into the matching part
        of the request; entries with any other target are ignored.

        Returns:
            The list of converted datasources, also stored in ``self.ds_list``.

        Raises:
            Exception: When no complete endpoint can be built, when the HTTP
                status is not 200, or when the response ``code`` is neither 0
                nor 200.
        """
        config: dict = json.loads(self.assistant.configuration)
        endpoint = self.get_complete_endpoint(endpoint=config['endpoint'])
        if not endpoint:
            raise Exception(
                f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]")

        # A missing/empty certificate means "no extra credentials" rather than
        # a crash inside json.loads.
        certificate_items: list = json.loads(self.certificate) if self.certificate else []
        header = {}
        cookies = {}
        param = {}
        buckets = {'header': header, 'cookie': cookies, 'param': param}
        for item in certificate_items:
            bucket = buckets.get(item['target'])
            if bucket is not None:
                bucket[item['key']] = item['value']

        timeout = int(config.get('timeout')) if config.get('timeout') else 10
        res = requests.get(url=endpoint, params=param, headers=header, cookies=cookies, timeout=timeout)
        if res.status_code != 200:
            SQLBotLogUtil.error(f"Failed to get datasource list from {endpoint}, response: {res}")
            raise Exception(f"Failed to get datasource list from {endpoint}, response: {res}")

        result_json: dict = json.loads(res.text)
        if result_json.get('code') not in (0, 200):
            raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")

        # "data": null is treated like an absent list.
        raw_items = result_json.get('data') or []
        self.ds_list = [self.convert2schema(item, config) for item in raw_items]
        return self.ds_list

    def get_first_element(self, text: str):
        """Return the first entry of a ``,``/``;``-separated string, stripped."""
        parts = re.split(r'[,;]', text.strip())
        return parts[0].strip()

    def get_complete_endpoint(self, endpoint: str) -> str | None:
        """Resolve ``endpoint`` to a full URL.

        An endpoint already starting with ``http://`` or ``https://`` is
        returned unchanged. Otherwise it is appended to a base taken from the
        assistant's ``domain``: with several domains (``,`` or ``;`` present)
        the base is ``request_origin`` when set, else the first listed domain,
        with surrounding ``/`` stripped; with a single domain the domain text
        is used as-is.

        Returns:
            The complete URL, or ``None`` when the assistant has no domain.
        """
        if endpoint.startswith(("http://", "https://")):
            return endpoint
        domain_text = self.assistant.domain
        if not domain_text:
            return None
        if ',' in domain_text or ';' in domain_text:
            base = self.request_origin if self.request_origin else self.get_first_element(domain_text)
            return base.strip('/') + endpoint
        return f"{domain_text}{endpoint}"

    def get_simple_ds_list(self):
        """Return ``id``/``name``/``description`` dictionaries for every datasource.

        ``description`` is taken from each datasource's ``comment``.

        Raises:
            Exception: When ``ds_list`` is empty or was never loaded.
        """
        if not self.ds_list:
            raise Exception("Datasource list is not found.")
        return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list]

    def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
                      table_list: list[str] = None) -> str:
        """Render the tables of one datasource as a textual schema description.

        Args:
            ds_id: Id of the datasource to describe.
            question: Not used by the current implementation.
            embedding: Not used by the current implementation.
            table_list: When given, only tables whose name is in this list are
                rendered.

        Returns:
            A string starting with the ``DB_ID``/``Schema`` header followed by
            one ``# Table: ...`` section per rendered table.

        Raises:
            Exception: Propagated from ``get_ds`` when the datasource is unknown.
        """
        ds = self.get_ds(ds_id)
        db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
        # Table names are prefixed with the database/schema name except for
        # the "mysql" and "es" datasource types.
        qualify_names = ds.type != "mysql" and ds.type != "es"

        sections = [f"【DB_ID】 {db_name}\n【Schema】\n"]
        # `tables`/`fields` are guarded with `or []` in case the schema leaves
        # them unset (schema definition not visible here).
        for table in ds.tables or []:
            # When table_list is given, only tables named in it are processed.
            if table_list is not None and table.name not in table_list:
                continue

            heading = f"# Table: {db_name}.{table.name}" if qualify_names else f"# Table: {table.name}"
            # Missing (None) and empty comments are both omitted, so the text
            # "None" never leaks into the rendered schema.
            if table.comment:
                heading += f", {table.comment}"

            field_list = [
                f"({field.name}:{field.type}, {field.comment})" if field.comment
                else f"({field.name}:{field.type})"
                for field in table.fields or []
            ]
            sections.append(heading + "\n[\n" + ",\n".join(field_list) + "\n]\n")

        # do table embedding
        # if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
        #     tables = get_table_embedding(tables, question)
        return "".join(sections)

    def get_ds(self, ds_id: int, trans: Trans = None):
        """Return the loaded datasource whose ``id`` equals ``ds_id``.

        Args:
            ds_id: Id to look for.
            trans: Optional translator; when given it produces the
                "id not found" message instead of the built-in English text.

        Raises:
            Exception: When ``ds_list`` is empty or was never loaded, or when
                no datasource has the requested id.
        """
        if not self.ds_list:
            raise Exception("Datasource list is not found.")
        for ds in self.ds_list:
            if ds.id == ds_id:
                return ds
        raise Exception(f"Datasource id {ds_id} is not found." if trans is None else trans(
            'i18n_data_training.datasource_id_not_found', key=ds_id))

    def convert2schema(self, ds_dict: dict, config: dict) -> AssistantOutDsSchema:
        """Convert one raw datasource dictionary into an ``AssistantOutDsSchema``.

        When ``config['encrypt']`` is truthy, the non-empty connection
        attributes are passed through ``simple_aes_decrypt`` with the
        configured ``aes_key``/``aes_iv``. A datasource without an ``id`` gets
        one derived via ``string_to_numeric_hash`` from its identifying
        attributes. The ``schema`` key is exposed as ``db_schema`` (falling
        back to an existing ``db_schema`` key, then to '').

        Args:
            ds_dict: Raw datasource description; it is not modified.
            config: Parsed assistant configuration.

        Raises:
            Exception: When decrypting an attribute fails; the original error
                is chained as the cause.
        """
        # Work on a shallow copy so the caller's dictionary stays untouched.
        ds_data = dict(ds_dict)

        if config.get('encrypt', False):
            key = config.get('aes_key', None)
            iv = config.get('aes_iv', None)
            for attr in ('host', 'user', 'password', 'dataBase', 'db_schema', 'schema', 'mode'):
                if not ds_data.get(attr):
                    continue
                try:
                    ds_data[attr] = simple_aes_decrypt(ds_data[attr], key, iv)
                except Exception as e:
                    raise Exception(
                        f"Failed to decrypt {attr} for datasource {ds_data.get('name')}, error: {str(e)}") from e

        ds_id = ds_data.get('id', None)
        if not ds_id:
            # No id supplied: derive one from the attributes that are present.
            marker_attrs = ['name', 'type', 'host', 'port', 'user', 'dataBase', 'schema', 'mode']
            id_marker = ''.join(
                str(ds_data[attr]) + '--sqlbot--' for attr in marker_attrs if attr in ds_data)
            ds_id = string_to_numeric_hash(id_marker)

        db_schema = ds_data.get('schema', ds_data.get('db_schema', ''))
        ds_data.pop("schema", None)
        return AssistantOutDsSchema(**{**ds_data, "id": ds_id, "db_schema": db_schema})
class AssistantOutDsFactory:
    """Factory that builds ``AssistantOutDs`` objects."""

    @staticmethod
    def get_instance(assistant: AssistantHeader) -> AssistantOutDs:
        """Create a new ``AssistantOutDs`` for ``assistant``.

        Args:
            assistant: Assistant header handed to the ``AssistantOutDs``
                constructor.

        Returns:
            The freshly constructed instance.
        """
        out_ds = AssistantOutDs(assistant)
        return out_ds
def get_out_ds_conf(ds: AssistantOutDsSchema, timeout: int = 30) -> str:
    """Serialize a datasource's connection settings and encrypt them.

    Missing (falsy) values are replaced by '' (or 0 for ``port``, 30 for
    ``timeout``). ``driver`` is always '' and ``extraJdbc`` is always blanked
    before serialization.

    Args:
        ds: Datasource whose connection attributes are read.
        timeout: Timeout value stored in the configuration; falsy values
            fall back to 30.

    Returns:
        The result of ``aes_encrypt`` applied to the JSON text of the
        configuration.
    """
    payload = {
        "host": ds.host if ds.host else '',
        "port": ds.port if ds.port else 0,
        "username": ds.user if ds.user else '',
        "password": ds.password if ds.password else '',
        "database": ds.dataBase if ds.dataBase else '',
        "driver": '',
        "extraJdbc": ds.extraParams if ds.extraParams else '',
        "dbSchema": ds.db_schema if ds.db_schema else '',
        "timeout": timeout if timeout else 30,
        "mode": ds.mode if ds.mode else '',
    }
    # The datasource's extra parameters are never forwarded: the entry is
    # reset to '' before encoding (presumably deliberate — confirm with owner).
    payload["extraJdbc"] = ''
    serialized = json.dumps(payload)
    return aes_encrypt(serialized)