Skip to content

Commit ec4e199

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents 91c2765 + 63cd878 commit ec4e199

2 files changed

Lines changed: 49 additions & 28 deletions

File tree

backend/apps/chat/task/llm.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -516,31 +516,9 @@ def generate_sql(self):
516516
[{'type': msg.type, 'content': msg.content} for msg in
517517
self.sql_message]).decode())
518518

519-
def generate_filter(self, sql: str, tables: List):
520-
table_list = self.session.query(CoreTable).filter(
521-
and_(CoreTable.ds_id == self.ds.id, CoreTable.table_name.in_(tables))
522-
).all()
523-
524-
filters = []
525-
for table in table_list:
526-
row_permissions = self.session.query(DsPermission).filter(
527-
and_(DsPermission.table_id == table.id, DsPermission.type == 'row')).all()
528-
res: List[PermissionDTO] = []
529-
if row_permissions is not None:
530-
for permission in row_permissions:
531-
# check permission and user in same rules
532-
obj = self.session.query(DsRules).filter(
533-
and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)),
534-
or_(DsRules.user_list.op('@>')(cast([f'{self.current_user.id}'], JSONB)),
535-
DsRules.user_list.op('@>')(cast([self.current_user.id], JSONB))))
536-
).first()
537-
if obj is not None:
538-
res.append(transRecord2DTO(self.session, permission))
539-
where_str = transFilterTree(self.session, res, self.ds)
540-
filters.append({"table": table.table_name, "filter": where_str})
541-
519+
520+
def build_table_filter(self, sql: str, filters: list):
542521
filter = json.dumps(filters, ensure_ascii=False)
543-
544522
self.chat_question.sql = sql
545523
self.chat_question.filter = filter
546524
msg: List[Union[BaseMessage, dict[str, Any]]] = []
@@ -588,7 +566,42 @@ def generate_filter(self, sql: str, tables: List):
588566
# analysis_msg]).decode())
589567
SQLBotLogUtil.info(full_filter_text)
590568
return full_filter_text
569+
570+
def generate_filter(self, sql: str, tables: List):
571+
table_list = self.session.query(CoreTable).filter(
572+
and_(CoreTable.ds_id == self.ds.id, CoreTable.table_name.in_(tables))
573+
).all()
591574

575+
filters = []
576+
for table in table_list:
577+
row_permissions = self.session.query(DsPermission).filter(
578+
and_(DsPermission.table_id == table.id, DsPermission.type == 'row')).all()
579+
res: List[PermissionDTO] = []
580+
if row_permissions is not None:
581+
for permission in row_permissions:
582+
# check permission and user in same rules
583+
obj = self.session.query(DsRules).filter(
584+
and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)),
585+
or_(DsRules.user_list.op('@>')(cast([f'{self.current_user.id}'], JSONB)),
586+
DsRules.user_list.op('@>')(cast([self.current_user.id], JSONB))))
587+
).first()
588+
if obj is not None:
589+
res.append(transRecord2DTO(self.session, permission))
590+
where_str = transFilterTree(self.session, res, self.ds)
591+
filters.append({"table": table.table_name, "filter": where_str})
592+
593+
return self.build_table_filter(sql=sql, filters=filters)
594+
595+
def generate_assistant_filter(self, sql, tables: List):
596+
ds: AssistantOutDsSchema = self.ds
597+
filters = []
598+
for table in ds.tables:
599+
if table.name in tables and table.rule:
600+
filters.append({"table": table.name, "filter": table.rule})
601+
if not filters:
602+
return None
603+
return self.build_table_filter(sql=sql, filters=filters)
604+
592605
def generate_chart(self):
593606
# append current question
594607
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question()))
@@ -802,7 +815,7 @@ def run_task(self, in_chat: bool = True):
802815
SQLBotLogUtil.info(full_sql_text)
803816

804817
# todo row permission
805-
if is_normal_user(self.current_user):
818+
if is_normal_user(self.current_user) or (self.current_assistant and self.current_assistant.type == 1):
806819
sql_json_str = extract_nested_json(full_sql_text)
807820
data = orjson.loads(sql_json_str)
808821

@@ -819,9 +832,16 @@ def run_task(self, in_chat: bool = True):
819832
if sql.strip() == '':
820833
raise Exception("SQL query is empty")
821834

822-
sql_result = self.generate_filter(data.get('sql'), data.get('tables')) # maybe no sql and tables
823-
SQLBotLogUtil.info(sql_result)
824-
sql = self.check_save_sql(res=sql_result)
835+
if self.current_assistant:
836+
sql_result = self.generate_assistant_filter(data.get('sql'), data.get('tables'))
837+
else:
838+
sql_result = self.generate_filter(data.get('sql'), data.get('tables')) # maybe no sql and tables
839+
840+
if sql_result:
841+
SQLBotLogUtil.info(sql_result)
842+
sql = self.check_save_sql(res=sql_result)
843+
else:
844+
sql = self.check_save_sql(res=full_sql_text)
825845
else:
826846
sql = self.check_save_sql(res=full_sql_text)
827847

backend/apps/system/crud/assistant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
192192
extraJdbc=ds.extraParams,
193193
dbSchema=ds.db_schema or ''
194194
)
195+
conf.extraJdbc = ''
195196
from apps.db.db import get_uri_from_config
196197
uri = get_uri_from_config(ds.type, conf)
197198
if ds.type == "pg" and ds.db_schema:

0 commit comments

Comments
 (0)