From c3bfa53ebaa5f73187d86342ca332eeeb87ad5d2 Mon Sep 17 00:00:00 2001 From: ulleo Date: Thu, 7 Aug 2025 16:44:19 +0800 Subject: [PATCH] improve: improve chat record select --- backend/apps/chat/api/chat.py | 58 +++--- backend/apps/chat/curd/chat.py | 218 ++++++++++++++++------ backend/apps/chat/task/llm.py | 64 +++---- backend/apps/datasource/api/datasource.py | 23 ++- backend/apps/db/db.py | 2 +- backend/template.yaml | 3 + frontend/src/api/datasource.ts | 2 +- frontend/src/views/ds/Card.vue | 13 ++ 8 files changed, 258 insertions(+), 125 deletions(-) diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index 23d379967..3cc59380e 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -6,9 +6,10 @@ import pandas as pd from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse +from sqlalchemy import and_, select from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \ - delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data + delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData from apps.chat.task.llm import LLMService from common.core.deps import CurrentAssistant, SessionDep, CurrentUser @@ -23,48 +24,37 @@ async def chats(session: SessionDep, current_user: CurrentUser): @router.get("/get/{chart_id}") async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant): - try: + def inner(): return get_chat_with_records(chart_id=chart_id, session=session, current_user=current_user, current_assistant=current_assistant) - except Exception as e: - raise HTTPException( - status_code=500, - detail=str(e) - ) + + return await asyncio.to_thread(inner) @router.get("/get/with_data/{chart_id}") -async def get_chat_with_data(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant): - try: +async def get_chat_with_data(session: SessionDep, current_user: CurrentUser, chart_id: int, + current_assistant: CurrentAssistant): + def inner(): return get_chat_with_records_with_data(chart_id=chart_id, session=session, current_user=current_user, current_assistant=current_assistant) - except Exception as e: - raise HTTPException( - status_code=500, - detail=str(e) - ) + + return await asyncio.to_thread(inner) @router.get("/record/get/{chart_record_id}/data") async def chat_record_data(session: SessionDep, chart_record_id: int): - try: + def inner(): return get_chat_chart_data(chart_record_id=chart_record_id, session=session) - except Exception as e: - raise HTTPException( - status_code=500, - detail=str(e) - ) + + return await asyncio.to_thread(inner) @router.get("/record/get/{chart_record_id}/predict_data") async def chat_predict_data(session: SessionDep, chart_record_id: int): - try: + def inner(): return get_chat_predict_data(chart_record_id=chart_record_id, session=session) - except Exception as e: - raise HTTPException( - status_code=500, - detail=str(e) - ) + + return await asyncio.to_thread(inner) @router.post("/rename") @@ -115,7 +105,8 @@ async def start_chat(session: SessionDep, current_user: CurrentUser): async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int, current_assistant: CurrentAssistant): try: - record = session.get(ChatRecord, chat_record_id) + record = get_chat_record_by_id(session, chat_record_id) + if not record: raise HTTPException( status_code=400, @@ -123,7 +114,7 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, ch ) request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '') - llm_service = LLMService(current_user, request_question, current_assistant) + llm_service = LLMService(current_user, request_question, current_assistant, True) llm_service.set_record(record) llm_service.run_recommend_questions_task_async() except Exception as e: @@ -172,8 +163,17 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch status_code=404, detail="Not Found" ) + record: ChatRecord | None = None + + stmt = select(ChatRecord.id, ChatRecord.question, ChatRecord.chat_id, ChatRecord.datasource, ChatRecord.engine_type, + ChatRecord.ai_modal_id, ChatRecord.create_by, ChatRecord.chart, ChatRecord.data).where( + and_(ChatRecord.id == chat_record_id)) + result = session.execute(stmt) + for r in result: + record = ChatRecord(id=r.id, question=r.question, chat_id=r.chat_id, datasource=r.datasource, + engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by, chart=r.chart, + data=r.data) - record = session.query(ChatRecord).get(chat_record_id) if not record: raise HTTPException( status_code=400, diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index 3b51aa328..b715b32e6 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -3,7 +3,7 @@ import orjson import sqlparse -from sqlalchemy import and_, select +from sqlalchemy import and_, select, update from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion from apps.datasource.models.datasource import CoreDatasource @@ -12,6 +12,19 @@ from common.utils.utils import extract_nested_json +def get_chat_record_by_id(session: SessionDep, record_id: int): + record: ChatRecord | None = None + + stmt = select(ChatRecord.id, ChatRecord.question, ChatRecord.chat_id, ChatRecord.datasource, ChatRecord.engine_type, + ChatRecord.ai_modal_id, ChatRecord.create_by).where( + and_(ChatRecord.id == record_id)) + result = session.execute(stmt) + for r in result: + record = ChatRecord(id=r.id, question=r.question, chat_id=r.chat_id, datasource=r.datasource, + engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by) + return record + + def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]: oid = current_user.oid if current_user.oid is not None else 1 chart_list = session.query(Chat).filter(and_(Chat.create_by == current_user.id, Chat.oid == oid)).order_by( @@ -45,6 +58,17 @@ def delete_chat(session, chart_id) -> str: return f'Chat with id {chart_id} has been deleted' +def get_chart_config(session: SessionDep, chart_record_id: int): + stmt = select(ChatRecord.chart).where(and_(ChatRecord.id == chart_record_id)) + res = session.execute(stmt) + for row in res: + try: + return orjson.loads(row.chart) + except Exception: + pass + return {} + + def get_chat_chart_data(session: SessionDep, chart_record_id: int): stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chart_record_id)) res = session.execute(stmt) @@ -53,7 +77,7 @@ def get_chat_chart_data(session: SessionDep, chart_record_id: int): return orjson.loads(row.data) except Exception: pass - return {} + return [] def get_chat_predict_data(session: SessionDep, chart_record_id: int): @@ -64,7 +88,7 @@ def get_chat_predict_data(session: SessionDep, chart_record_id: int): return orjson.loads(row.predict_data) except Exception: pass - return {} + return [] def get_chat_with_records_with_data(session: SessionDep, chart_id: int, current_user: CurrentUser, @@ -185,10 +209,17 @@ def format_record(record: ChatRecord): def list_base_records(session: SessionDep, chart_id: int, current_user: CurrentUser) -> List[ChatRecord]: - record_list = session.query(ChatRecord).filter( - and_(Chat.create_by == current_user.id, ChatRecord.chat_id == chart_id, + stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.full_sql_message, ChatRecord.full_chart_message, + ChatRecord.first_chat, ChatRecord.create_time).where( + and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id, ChatRecord.analysis_record_id.is_(None), ChatRecord.predict_record_id.is_(None))).order_by( - ChatRecord.create_time).all() + ChatRecord.create_time) + result = session.execute(stmt).all() + record_list: List[ChatRecord] = [] + for r in result: + record_list.append( + ChatRecord(id=r.id, chat_id=r.chat_id, create_time=r.create_time, full_sql_message=r.full_sql_message, + full_chart_message=r.full_chart_message, first_chat=r.first_chat)) return record_list @@ -321,7 +352,8 @@ def save_full_sql_message_and_answer(session: SessionDep, record_id: int, answer token_usage: dict = None) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.full_sql_message = full_message record.sql_answer = answer @@ -330,9 +362,13 @@ def save_full_sql_message_and_answer(session: SessionDep, record_id: int, answer result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + full_sql_message=record.full_sql_message, + sql_answer=record.sql_answer, + token_sql=record.token_sql, + ) + + session.execute(stmt) session.commit() @@ -343,7 +379,8 @@ def save_full_analysis_message_and_answer(session: SessionDep, record_id: int, a full_message: str, token_usage: dict = None) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.full_analysis_message = full_message record.analysis = answer @@ -352,9 +389,13 @@ def save_full_analysis_message_and_answer(session: SessionDep, record_id: int, a result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + full_analysis_message=record.full_analysis_message, + analysis=record.analysis, + token_analysis=record.token_analysis, + ) + + session.execute(stmt) session.commit() @@ -365,7 +406,8 @@ def save_full_predict_message_and_answer(session: SessionDep, record_id: int, an full_message: str, data: str, token_usage: dict = None) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.full_predict_message = full_message record.predict = answer record.predict_data = data @@ -375,9 +417,14 @@ def save_full_predict_message_and_answer(session: SessionDep, record_id: int, an result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + full_predict_message=record.full_predict_message, + predict=record.predict, + predict_data=record.predict_data, + token_predict=record.token_predict + ) + + session.execute(stmt) session.commit() @@ -389,7 +436,8 @@ def save_full_select_datasource_message_and_answer(session: SessionDep, record_i engine_type: str = None, token_usage: dict = None) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.get(ChatRecord, record_id) + record = get_chat_record_by_id(session, record_id) + record.full_select_datasource_message = full_message record.datasource_select_answer = answer @@ -402,9 +450,22 @@ def save_full_select_datasource_message_and_answer(session: SessionDep, record_i result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + if datasource: + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + full_select_datasource_message=record.full_select_datasource_message, + datasource_select_answer=record.datasource_select_answer, + token_select_datasource_question=record.token_select_datasource_question, + datasource=record.datasource, + engine_type=record.engine_type, + ) + else: + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + full_select_datasource_message=record.full_select_datasource_message, + datasource_select_answer=record.datasource_select_answer, + token_select_datasource_question=record.token_select_datasource_question, + ) + + session.execute(stmt) session.commit() @@ -415,9 +476,12 @@ def save_full_recommend_question_message_and_answer(session: SessionDep, record_ full_message: str = '[]', token_usage: dict = None) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.full_recommended_question_message = full_message - record.recommended_question_answer = orjson.dumps(answer).decode() + + if answer: + record.recommended_question_answer = orjson.dumps(answer).decode() json_str = '[]' if answer and answer.get('content') and answer.get('content') != '': @@ -435,9 +499,14 @@ def save_full_recommend_question_message_and_answer(session: SessionDep, record_ result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + full_recommended_question_message=record.full_recommended_question_message, + recommended_question_answer=record.recommended_question_answer, + recommended_question=record.recommended_question, + token_recommended_question=record.token_recommended_question + ) + + session.execute(stmt) session.commit() @@ -447,14 +516,18 @@ def save_full_recommend_question_message_and_answer(session: SessionDep, record_ def save_sql(session: SessionDep, record_id: int, sql: str) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + + record = get_chat_record_by_id(session, record_id) + record.sql = sql result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + sql=record.sql + ) + + session.execute(stmt) session.commit() @@ -470,7 +543,8 @@ def save_full_chart_message_and_answer(session: SessionDep, record_id: int, answ full_message: str, token_usage: dict = None) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.full_chart_message = full_message record.chart_answer = answer @@ -479,9 +553,13 @@ def save_full_chart_message_and_answer(session: SessionDep, record_id: int, answ result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + full_chart_message=record.full_chart_message, + chart_answer=record.chart_answer, + token_chart=record.token_chart + ) + + session.execute(stmt) session.commit() @@ -491,14 +569,17 @@ def save_full_chart_message_and_answer(session: SessionDep, record_id: int, answ def save_chart(session: SessionDep, record_id: int, chart: str) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.chart = chart result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + chart=record.chart + ) + + session.execute(stmt) session.commit() @@ -508,14 +589,17 @@ def save_chart(session: SessionDep, record_id: int, chart: str) -> ChatRecord: def save_predict_data(session: SessionDep, record_id: int, data: str = '') -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.predict_data = data result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + predict_data=record.predict_data + ) + + session.execute(stmt) session.commit() @@ -525,16 +609,21 @@ def save_predict_data(session: SessionDep, record_id: int, data: str = '') -> Ch def save_error_message(session: SessionDep, record_id: int, message: str) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.error = message record.finish = True record.finish_time = datetime.datetime.now() result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + error=record.error, + finish=record.finish, + finish_time=record.finish_time + ) + + session.execute(stmt) session.commit() @@ -544,14 +633,17 @@ def save_error_message(session: SessionDep, record_id: int, message: str) -> Cha def save_sql_exec_data(session: SessionDep, record_id: int, data: str) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.data = data result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + data=record.data, + ) + + session.execute(stmt) session.commit() @@ -561,15 +653,19 @@ def save_sql_exec_data(session: SessionDep, record_id: int, data: str) -> ChatRe def finish_record(session: SessionDep, record_id: int) -> ChatRecord: if not record_id: raise Exception("Record id cannot be None") - record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first() + record = get_chat_record_by_id(session, record_id) + record.finish = True record.finish_time = datetime.datetime.now() result = ChatRecord(**record.model_dump()) - session.add(record) - session.flush() - session.refresh(record) + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + finish=record.finish, + finish_time=record.finish_time + ) + + session.execute(stmt) session.commit() @@ -577,9 +673,13 @@ def finish_record(session: SessionDep, record_id: int) -> ChatRecord: def get_old_questions(session: SessionDep, datasource: int): + records = [] if not datasource: - return [] - records = session.query(ChatRecord.question, ChatRecord.create_time).filter(ChatRecord.datasource == datasource, - ChatRecord.question != None).order_by( - ChatRecord.create_time.desc()).limit(20).all() + return records + stmt = select(ChatRecord.question).where( + and_(ChatRecord.datasource == datasource, ChatRecord.question.isnot(None))).order_by( + ChatRecord.create_time.desc()).limit(20) + result = session.execute(stmt) + for r in result: + records.append(r.question) return records diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 34640ff4e..353df4ee1 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -22,7 +22,8 @@ save_error_message, save_sql_exec_data, save_full_chart_message, save_full_chart_message_and_answer, save_chart, \ finish_record, save_full_analysis_message_and_answer, save_full_predict_message_and_answer, save_predict_data, \ save_full_select_datasource_message_and_answer, save_full_recommend_question_message_and_answer, \ - get_old_questions, save_analysis_predict_record, list_base_records, rename_chat + get_old_questions, save_analysis_predict_record, list_base_records, rename_chat, get_chart_config, \ + get_chat_chart_data from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat from apps.datasource.crud.datasource import get_table_schema from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user @@ -61,7 +62,7 @@ class LLMService: future: Future def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, - current_assistant: Optional[CurrentAssistant] = None): + current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False): self.chunk_list = [] engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) session_maker = sessionmaker(bind=engine) @@ -104,6 +105,13 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, self.ds = (ds if isinstance(ds, AssistantOutDsSchema) else CoreDatasource(**ds.model_dump())) if ds else None self.chat_question = chat_question self.config = get_default_config() + if no_reasoning: + # only work while using qwen + if self.config.additional_params: + if self.config.additional_params.get('extra_body'): + if self.config.additional_params.get('extra_body').get('enable_thinking'): + del self.config.additional_params['extra_body']['enable_thinking'] + self.chat_question.ai_modal_id = self.config.model_id # Create LLM instance through factory @@ -186,7 +194,7 @@ def set_record(self, record: ChatRecord): self.record = record def get_fields_from_chart(self): - chart_info = orjson.loads(self.record.chart) + chart_info = get_chart_config(self.session, self.record.id) fields = [] if chart_info.get('columns') and len(chart_info.get('columns')) > 0: for column in chart_info.get('columns'): @@ -206,20 +214,15 @@ def get_fields_from_chart(self): def generate_analysis(self): fields = self.get_fields_from_chart() - self.chat_question.fields = orjson.dumps(fields).decode() - self.chat_question.data = orjson.dumps(orjson.loads(self.record.data).get('data')).decode() + data = get_chat_chart_data(self.session, self.record.id) + self.chat_question.data = orjson.dumps(data.get('data')).decode() analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = [] analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question())) analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) - history_msg = [] - if self.record.full_analysis_message and self.record.full_analysis_message.strip() != '': - history_msg = orjson.loads(self.record.full_analysis_message) - self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id, answer='', - full_message=orjson.dumps(history_msg + - [{'type': msg.type, + full_message=orjson.dumps([{'type': msg.type, 'content': msg.content} for msg in analysis_msg]).decode()) @@ -247,29 +250,23 @@ def generate_analysis(self): token_usage=token_usage, answer=orjson.dumps({'content': full_analysis_text, 'reasoning_content': full_thinking_text}).decode(), - full_message=orjson.dumps(history_msg + - [{'type': msg.type, + full_message=orjson.dumps([{'type': msg.type, 'content': msg.content} for msg in analysis_msg]).decode()) def generate_predict(self): fields = self.get_fields_from_chart() - self.chat_question.fields = orjson.dumps(fields).decode() - self.chat_question.data = orjson.dumps(orjson.loads(self.record.data).get('data')).decode() + data = get_chat_chart_data(self.session, self.record.id) + self.chat_question.data = orjson.dumps(data.get('data')).decode() predict_msg: List[Union[BaseMessage, dict[str, Any]]] = [] predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question())) predict_msg.append(HumanMessage(content=self.chat_question.predict_user_question())) - history_msg = [] - if self.record.full_predict_message and self.record.full_predict_message.strip() != '': - history_msg = orjson.loads(self.record.full_predict_message) - self.record = save_full_predict_message_and_answer(session=self.session, record_id=self.record.id, answer='', data='', - full_message=orjson.dumps(history_msg + - [{'type': msg.type, + full_message=orjson.dumps([{'type': msg.type, 'content': msg.content} for msg in predict_msg]).decode()) @@ -298,8 +295,7 @@ def generate_predict(self): answer=orjson.dumps({'content': full_predict_text, 'reasoning_content': full_thinking_text}).decode(), data='', - full_message=orjson.dumps(history_msg + - [{'type': msg.type, + full_message=orjson.dumps([{'type': msg.type, 'content': msg.content} for msg in predict_msg]).decode()) @@ -315,7 +311,7 @@ def generate_recommend_questions_task(self): guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question())) - old_questions = list(map(lambda q: q[0].strip(), get_old_questions(self.session, self.record.datasource))) + old_questions = list(map(lambda q: q.strip(), get_old_questions(self.session, self.record.datasource))) guess_msg.append( HumanMessage(content=self.chat_question.guess_user_question(orjson.dumps(old_questions).decode()))) @@ -600,7 +596,8 @@ def build_table_filter(self, sql: str, filters: list): return full_filter_text def generate_filter(self, sql: str, tables: List): - filters = get_row_permission_filters(session=self.session, current_user=self.current_user, ds=self.ds, tables=tables) + filters = get_row_permission_filters(session=self.session, current_user=self.current_user, ds=self.ds, + tables=tables) if not filters: return None return self.build_table_filter(sql=sql, filters=filters) @@ -718,7 +715,7 @@ def check_save_chart(self, res: str) -> Dict[str, Any]: except Exception: error = True message = orjson.dumps({'message': 'Cannot parse chart config from answer', - 'traceback': "Cannot parse chart config from answer:\n" + res}).decode() + 'traceback': "Cannot parse chart config from answer:\n" + res}).decode() if error: raise SingleMessageError(message) @@ -745,12 +742,15 @@ def save_error(self, message: str): return save_error_message(session=self.session, record_id=self.record.id, message=message) def save_sql_data(self, data_obj: Dict[str, Any]): - data_result = data_obj.get('data') - if data_result: - data_result = prepare_for_orjson(data_result) - data_obj['data'] = data_result - return save_sql_exec_data(session=self.session, record_id=self.record.id, - data=orjson.dumps(data_obj).decode()) + try: + data_result = data_obj.get('data') + if data_result: + data_result = prepare_for_orjson(data_result) + data_obj['data'] = data_result + return save_sql_exec_data(session=self.session, record_id=self.record.id, + data=orjson.dumps(data_obj).decode()) + except Exception as e: + raise e def finish(self): return finish_record(session=self.session, record_id=self.record.id) diff --git a/backend/apps/datasource/api/datasource.py b/backend/apps/datasource/api/datasource.py index fa7dc0471..2e7a3ef74 100644 --- a/backend/apps/datasource/api/datasource.py +++ b/backend/apps/datasource/api/datasource.py @@ -1,9 +1,11 @@ import asyncio import hashlib import os +import traceback import uuid from typing import List +import orjson import pandas as pd from fastapi import APIRouter, File, UploadFile, HTTPException @@ -109,10 +111,25 @@ async def get_fields(session: SessionDep, id: int, table_name: str): return getFields(session, id, table_name) -@router.post("/execSql/{id}/{sql}") -async def exec_sql(session: SessionDep, id: int, sql: str): +from pydantic import BaseModel + + +class TestObj(BaseModel): + sql: str = None + + +@router.post("/execSql/{id}") +async def exec_sql(session: SessionDep, id: int, obj: TestObj): def inner(): - return execSql(session, id, sql) + data = execSql(session, id, obj.sql) + try: + data_obj = data.get('data') + # print(orjson.dumps(data, option=orjson.OPT_NON_STR_KEYS).decode()) + print(orjson.dumps(data_obj).decode()) + except Exception: + traceback.print_exc() + + return data return await asyncio.to_thread(inner) diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index f27194a25..d081cd040 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -256,7 +256,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str): columns = [item.lower() for item in result.keys()._keys] res = result.fetchall() result_list = [ - {columns[i]: float(value) if isinstance(value, Decimal) else value for i, value in + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in enumerate(tuple_item)} for tuple_item in res ] diff --git a/backend/template.yaml b/backend/template.yaml index 0c86b1ce9..f5a9dbf9c 100644 --- a/backend/template.yaml +++ b/backend/template.yaml @@ -6,8 +6,11 @@ template: 任务: 根据给定的表结构(M-Schema)和用户问题生成符合{engine}数据库引擎规范的sql语句,以及sql中所用到的表名(不要包含schema和database,用数组返回)。 你必须遵守以下规则: + - 只能生成查询用的sql语句,不得生成增删改相关或操作数据库以及操作数据库数据的sql - 不要编造没有提供给你的表结构 - 生成的SQL必须符合{engine}的规范。 + - 若用户要求执行某些sql,若此sql不是查询数据,而是增删改相关或操作数据库以及操作数据库数据等操作,则直接回答: + {{"success":false,"message":"抱歉,我不能执行您指定的SQL语句。"}} - 你的回答必须使用如下JSON格式返回: {{"success":true,"sql":"生成的SQL语句","tables":["表名1","表名2",...]}} - 问题与生成SQL无关时,直接回答: diff --git a/frontend/src/api/datasource.ts b/frontend/src/api/datasource.ts index 493421ec0..51f6ef220 100644 --- a/frontend/src/api/datasource.ts +++ b/frontend/src/api/datasource.ts @@ -11,7 +11,7 @@ export const datasourceApi = { getTablesByConf: (data: any) => request.post('/datasource/getTablesByConf', data), getFields: (id: number, table_name: string) => request.post(`/datasource/getFields/${id}/${table_name}`), - execSql: (id: number, sql: string) => request.post(`/datasource/execSql/${id}/${sql}`), + execSql: (id: number, sql: string) => request.post(`/datasource/execSql/${id}`, { sql: sql }), chooseTables: (id: number, data: any) => request.post(`/datasource/chooseTables/${id}`, data), tableList: (id: number) => request.post(`/datasource/tableList/${id}`), fieldList: (id: number) => request.post(`/datasource/fieldList/${id}`), diff --git a/frontend/src/views/ds/Card.vue b/frontend/src/views/ds/Card.vue index abc172363..98bbdabf0 100644 --- a/frontend/src/views/ds/Card.vue +++ b/frontend/src/views/ds/Card.vue @@ -48,6 +48,18 @@ const handleQuestion = () => { }) } +function runSQL() { + datasourceApi.execSql( + props.id, + 'SELECT TO_CHAR(FLOOR("c"."CREDIT_LIMIT" / 10000) * 10000) || \' - \' || TO_CHAR((FLOOR("c"."CREDIT_LIMIT" / 10000) + 1) * 10000) AS "credit_range",\n' + + ' COUNT(*) AS "customer_count"\n' + + 'FROM "WEI"."CUSTOMERS" "c"\n' + + 'WHERE "c"."CREDIT_LIMIT" IS NOT NULL\n' + + 'GROUP BY FLOOR("c"."CREDIT_LIMIT" / 10000)\n' + + 'ORDER BY FLOOR("c"."CREDIT_LIMIT" / 10000)' + ) +} + const dataTableDetail = () => { emits('dataTableDetail') } @@ -79,6 +91,7 @@ const onClickOutside = () => { {{ num }}
+ test