1818 save_error_message , save_sql_exec_data , save_full_chart_message , save_full_chart_message_and_answer , save_chart , \
1919 finish_record , save_full_analysis_message_and_answer , save_full_predict_message_and_answer , save_predict_data , \
2020 save_full_select_datasource_message_and_answer , save_full_recommend_question_message_and_answer , \
21- get_old_questions , save_analysis_predict_record , list_base_records
22- from apps .chat .models .chat_model import ChatQuestion , ChatRecord , Chat
21+ get_old_questions , save_analysis_predict_record , list_base_records , rename_chat
22+ from apps .chat .models .chat_model import ChatQuestion , ChatRecord , Chat , RenameChat
2323from apps .datasource .crud .datasource import get_table_schema
2424from apps .datasource .models .datasource import CoreDatasource
2525from apps .db .db import exec_sql
@@ -45,13 +45,15 @@ class LLMService:
4545 session : SessionDep
4646 current_user : CurrentUser
4747 current_assistant : Optional [CurrentAssistant ] = None
48+ change_title : bool = False
4849
49- def __init__ (self , session : SessionDep , current_user : CurrentUser , chat_question : ChatQuestion , current_assistant : Optional [CurrentAssistant ] = None ):
50+ def __init__ (self , session : SessionDep , current_user : CurrentUser , chat_question : ChatQuestion ,
51+ current_assistant : Optional [CurrentAssistant ] = None ):
5052
5153 self .session = session
5254 self .current_user = current_user
5355 self .current_assistant = current_assistant
54- #chat = self.session.query(Chat).filter(Chat.id == chat_question.chat_id).first()
56+ # chat = self.session.query(Chat).filter(Chat.id == chat_question.chat_id).first()
5557 chat_id = chat_question .chat_id
5658 chat : Chat = self .session .get (Chat , chat_id )
5759 if not chat :
@@ -71,6 +73,8 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
7173 list_base_records (session = self .session ,
7274 current_user = current_user ,
7375 chart_id = chat_id ))))
76+ self .change_title = len (history_records ) == 0
77+
7478 # get schema
7579 if ds :
7680 chat_question .db_schema = get_table_schema (session = self .session , ds = ds )
@@ -335,7 +339,7 @@ def select_datasource(self):
335339 datasource_msg : List [Union [BaseMessage , dict [str , Any ]]] = []
336340 datasource_msg .append (SystemMessage (self .chat_question .datasource_sys_question ()))
337341 if self .current_assistant :
338- _ds_list = get_assistant_ds (session = self .session , assistant = self .current_assistant )
342+ _ds_list = get_assistant_ds (session = self .session , assistant = self .current_assistant )
339343 else :
340344 _ds_list = self .session .exec (select (CoreDatasource ).options (
341345 load_only (CoreDatasource .id , CoreDatasource .name , CoreDatasource .description ))).all ()
@@ -620,6 +624,14 @@ def run_task(llm_service: LLMService, in_chat: bool = True):
620624 if in_chat :
621625 yield orjson .dumps ({'type' : 'id' , 'id' : llm_service .get_record ().id }).decode () + '\n \n '
622626
627+ # return title
628+ if llm_service .change_title :
629+ if llm_service .chat_question .question or llm_service .chat_question .question .strip () != '' :
630+ brief = rename_chat (session = llm_service .session ,
631+ rename_object = RenameChat (id = llm_service .get_record ().chat_id ,
632+ brief = llm_service .chat_question .question .strip ()[:20 ]))
633+ yield orjson .dumps ({'type' : 'brief' , 'brief' : brief }).decode () + '\n \n '
634+
623635 # select datasource if datasource is none
624636 if not llm_service .ds :
625637 ds_res = llm_service .select_datasource ()
0 commit comments