Skip to content

Commit b06891e

Browse files
committed
feat: execute sql & generate chart with history
1 parent cc19a84 commit b06891e

File tree

10 files changed

+290
-553
lines changed

10 files changed

+290
-553
lines changed

backend/apps/chat/api/chat.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
import asyncio
21
import json
2+
import traceback
33
from typing import List
44

55
from fastapi import APIRouter, HTTPException
66
from fastapi.responses import StreamingResponse
77
from sqlmodel import select
88

9-
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, save_question, save_answer, rename_chat, \
9+
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
1010
delete_chat, list_records
1111
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, Chat, ChatQuestion
1212
from apps.chat.task.llm import LLMService
@@ -79,7 +79,6 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
7979
Returns:
8080
Streaming response with analysis results
8181
"""
82-
question = request_question.question
8382

8483
chat = session.query(Chat).filter(Chat.id == request_question.chat_id).first()
8584
if not chat:
@@ -113,46 +112,51 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
113112
chart_id=request_question.chat_id)
114113
# get schema
115114
request_question.db_schema = get_table_schema(session=session, ds=ds)
116-
llm_service = LLMService(request_question, history_records, ds, aimodel)
115+
llm_service = LLMService(request_question, history_records, CoreDatasource(**ds.model_dump()), aimodel)
117116

118117
llm_service.init_record(session=session, current_user=current_user)
119118

120119
def run_task():
121-
sql_res = llm_service.generate_sql(session=session)
122-
for chunk in sql_res:
123-
yield json.dumps({'content': chunk, 'type': 'sql'}) + '\n\n'
124-
yield json.dumps({'type': 'info', 'msg': 'sql generated'}) + '\n\n'
125-
126-
# async def event_generator():
127-
# all_text = ''
128-
# try:
129-
# async for chunk in llm_service.async_generate(question, request_question.db_schema):
130-
# data = json.loads(chunk.replace('data: ', ''))
131-
#
132-
# if data['type'] in ['final', 'tool_result']:
133-
# content = data['content']
134-
# print('-- ' + content)
135-
# all_text += content
136-
# for char in content:
137-
# yield f"data: {json.dumps({'type': 'char', 'content': char})}\n\n"
138-
# await asyncio.sleep(0.05)
139-
#
140-
# if 'html' in data:
141-
# yield f"data: {json.dumps({'type': 'html', 'content': data['html']})}\n\n"
142-
# else:
143-
# yield chunk
144-
#
145-
# except Exception as e:
146-
# all_text = 'Exception:' + str(e)
147-
# yield f"data: {json.dumps({'type': 'error', 'content': str(e)})}\n\n"
148-
#
149-
# try:
150-
# save_answer(session=session, id=record.id, answer=all_text)
151-
# except Exception as e:
152-
# raise HTTPException(
153-
# status_code=500,
154-
# detail=str(e)
155-
# )
156-
157-
# return EventSourceResponse(event_generator(), headers={"Content-Type": "text/event-stream"})
120+
try:
121+
# generate sql
122+
sql_res = llm_service.generate_sql(session=session)
123+
full_sql_text = ''
124+
for chunk in sql_res:
125+
full_sql_text += chunk
126+
yield json.dumps({'content': chunk, 'type': 'sql-result'}, ensure_ascii=False) + '\n\n'
127+
yield json.dumps({'type': 'info', 'msg': 'sql generated'}) + '\n\n'
128+
129+
# filter sql
130+
print(full_sql_text)
131+
sql = llm_service.check_save_sql(session=session, res=full_sql_text)
132+
print(sql)
133+
yield json.dumps({'content': sql, 'type': 'sql'}) + '\n\n'
134+
135+
# execute sql
136+
result = llm_service.execute_sql(sql=sql)
137+
llm_service.save_sql_data(session=session, data_obj=result)
138+
yield json.dumps({'content': result, 'type': 'sql-data'}, ensure_ascii=False) + '\n\n'
139+
140+
# generate chart
141+
chart_res = llm_service.generate_chart(session=session)
142+
full_chart_text = ''
143+
for chunk in chart_res:
144+
full_chart_text += chunk
145+
yield json.dumps({'content': chunk, 'type': 'chart-result'}, ensure_ascii=False) + '\n\n'
146+
yield json.dumps({'type': 'info', 'msg': 'chart generated'}) + '\n\n'
147+
148+
# filter chart
149+
print(full_chart_text)
150+
chart = llm_service.check_save_chart(session=session, res=full_chart_text)
151+
print(chart)
152+
yield json.dumps({'content': chart, 'type': 'chart'}, ensure_ascii=False) + '\n\n'
153+
154+
llm_service.finish(session=session)
155+
yield json.dumps({'type': 'finish'})
156+
157+
except Exception as e:
158+
traceback.print_exc()
159+
llm_service.save_error(session=session, message=str(e))
160+
yield json.dumps({'content': str(e), 'type': 'error'}, ensure_ascii=False) + '\n\n'
161+
158162
return StreamingResponse(run_task(), media_type="text/event-stream")

backend/apps/chat/curd/chat.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,35 @@ def save_full_sql_message_and_answer(session: SessionDep, record_id: int, answer
150150
return result
151151

152152

153-
def save_full_chart_message(session: SessionDep, id: int, full_message: str) -> ChatRecord:
154-
if not id:
153+
def save_sql(session: SessionDep, record_id: int, sql: str) -> ChatRecord:
154+
if not record_id:
155155
raise Exception("Record id cannot be None")
156-
record = session.query(ChatRecord).filter(ChatRecord.id == id).first()
156+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
157+
record.sql = sql
158+
159+
result = ChatRecord(**record.model_dump())
160+
161+
session.add(record)
162+
session.flush()
163+
session.refresh(record)
164+
165+
session.commit()
166+
167+
return result
168+
169+
170+
def save_full_chart_message(session: SessionDep, record_id: int, full_message: str) -> ChatRecord:
171+
return save_full_chart_message_and_answer(session=session, record_id=record_id, full_message=full_message,
172+
answer='')
173+
174+
175+
def save_full_chart_message_and_answer(session: SessionDep, record_id: int, answer: str,
176+
full_message: str) -> ChatRecord:
177+
if not record_id:
178+
raise Exception("Record id cannot be None")
179+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
157180
record.full_chart_message = full_message
181+
record.chart_answer = answer
158182

159183
result = ChatRecord(**record.model_dump())
160184

@@ -167,12 +191,62 @@ def save_full_chart_message(session: SessionDep, id: int, full_message: str) ->
167191
return result
168192

169193

170-
def save_answer(session: SessionDep, id: int, answer: str) -> ChatRecord:
171-
if not id:
194+
def save_chart(session: SessionDep, record_id: int, chart: str) -> ChatRecord:
195+
if not record_id:
172196
raise Exception("Record id cannot be None")
197+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
198+
record.chart = chart
199+
200+
result = ChatRecord(**record.model_dump())
201+
202+
session.add(record)
203+
session.flush()
204+
session.refresh(record)
205+
206+
session.commit()
207+
208+
return result
173209

174-
record = session.query(ChatRecord).filter(ChatRecord.id == id).first()
175-
record.answer = answer
210+
211+
def save_error_message(session: SessionDep, record_id: int, message: str) -> ChatRecord:
212+
if not record_id:
213+
raise Exception("Record id cannot be None")
214+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
215+
record.error = message
216+
record.finish = True
217+
218+
result = ChatRecord(**record.model_dump())
219+
220+
session.add(record)
221+
session.flush()
222+
session.refresh(record)
223+
224+
session.commit()
225+
226+
return result
227+
228+
229+
def save_sql_exec_data(session: SessionDep, record_id: int, data: str) -> ChatRecord:
230+
if not record_id:
231+
raise Exception("Record id cannot be None")
232+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
233+
record.data = data
234+
235+
result = ChatRecord(**record.model_dump())
236+
237+
session.add(record)
238+
session.flush()
239+
session.refresh(record)
240+
241+
session.commit()
242+
243+
return result
244+
245+
def finish_record(session: SessionDep, record_id: int) -> ChatRecord:
246+
if not record_id:
247+
raise Exception("Record id cannot be None")
248+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
249+
record.finish = True
176250

177251
result = ChatRecord(**record.model_dump())
178252

@@ -183,3 +257,4 @@ def save_answer(session: SessionDep, id: int, answer: str) -> ChatRecord:
183257
session.commit()
184258

185259
return result
260+

backend/apps/chat/schemas/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)