|
1 | | -import asyncio |
2 | 1 | import json |
| 2 | +import traceback |
3 | 3 | from typing import List |
4 | 4 |
|
5 | 5 | from fastapi import APIRouter, HTTPException |
6 | 6 | from fastapi.responses import StreamingResponse |
7 | 7 | from sqlmodel import select |
8 | 8 |
|
9 | | -from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, save_question, save_answer, rename_chat, \ |
| 9 | +from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \ |
10 | 10 | delete_chat, list_records |
11 | 11 | from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, Chat, ChatQuestion |
12 | 12 | from apps.chat.task.llm import LLMService |
@@ -79,7 +79,6 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que |
79 | 79 | Returns: |
80 | 80 | Streaming response with analysis results |
81 | 81 | """ |
82 | | - question = request_question.question |
83 | 82 |
|
84 | 83 | chat = session.query(Chat).filter(Chat.id == request_question.chat_id).first() |
85 | 84 | if not chat: |
@@ -113,46 +112,51 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que |
113 | 112 | chart_id=request_question.chat_id) |
114 | 113 | # get schema |
115 | 114 | request_question.db_schema = get_table_schema(session=session, ds=ds) |
116 | | - llm_service = LLMService(request_question, history_records, ds, aimodel) |
| 115 | + llm_service = LLMService(request_question, history_records, CoreDatasource(**ds.model_dump()), aimodel) |
117 | 116 |
|
118 | 117 | llm_service.init_record(session=session, current_user=current_user) |
119 | 118 |
|
def run_task():
    """Stream the SQL-analysis pipeline as newline-delimited JSON events.

    Runs each stage in order and yields one event per step: streamed SQL
    tokens ('sql-result'), the validated SQL ('sql'), the query result
    ('sql-data'), streamed chart tokens ('chart-result'), the validated
    chart config ('chart'), and finally 'finish' — or a single 'error'
    event if any stage raises.

    NOTE(review): this is a closure over the enclosing request scope's
    ``llm_service`` and ``session``; it must stay nested inside the
    endpoint that builds them.
    """
    try:
        # Stage 1: stream SQL generation token-by-token to the client.
        sql_chunks = []
        for chunk in llm_service.generate_sql(session=session):
            sql_chunks.append(chunk)
            yield json.dumps({'content': chunk, 'type': 'sql-result'}, ensure_ascii=False) + '\n\n'
        yield json.dumps({'type': 'info', 'msg': 'sql generated'}) + '\n\n'
        # Join once instead of quadratic '+=' accumulation in the loop.
        full_sql_text = ''.join(sql_chunks)

        # Stage 2: extract/validate the SQL from the raw LLM output and persist it.
        sql = llm_service.check_save_sql(session=session, res=full_sql_text)
        # ensure_ascii=False for consistency with every other content event.
        yield json.dumps({'content': sql, 'type': 'sql'}, ensure_ascii=False) + '\n\n'

        # Stage 3: execute the SQL and persist the result set.
        result = llm_service.execute_sql(sql=sql)
        llm_service.save_sql_data(session=session, data_obj=result)
        yield json.dumps({'content': result, 'type': 'sql-data'}, ensure_ascii=False) + '\n\n'

        # Stage 4: stream chart generation token-by-token.
        chart_chunks = []
        for chunk in llm_service.generate_chart(session=session):
            chart_chunks.append(chunk)
            yield json.dumps({'content': chunk, 'type': 'chart-result'}, ensure_ascii=False) + '\n\n'
        yield json.dumps({'type': 'info', 'msg': 'chart generated'}) + '\n\n'
        full_chart_text = ''.join(chart_chunks)

        # Stage 5: extract/validate the chart config and persist it.
        chart = llm_service.check_save_chart(session=session, res=full_chart_text)
        yield json.dumps({'content': chart, 'type': 'chart'}, ensure_ascii=False) + '\n\n'

        llm_service.finish(session=session)
        # Terminate the finish event with '\n\n' like every other event,
        # so clients that split the stream on blank lines receive it.
        yield json.dumps({'type': 'finish'}) + '\n\n'

    except Exception as e:
        # Record the failure server-side, then surface it to the client
        # as an 'error' event instead of silently breaking the stream.
        traceback.print_exc()
        llm_service.save_error(session=session, message=str(e))
        yield json.dumps({'content': str(e), 'type': 'error'}, ensure_ascii=False) + '\n\n'
158 | 162 | return StreamingResponse(run_task(), media_type="text/event-stream") |
0 commit comments