Skip to content

Commit 80484c3

Browse files
perf: Support not only pgsql but mysql
1 parent 6a9b5ea commit 80484c3

File tree

6 files changed

+64
-65
lines changed

6 files changed

+64
-65
lines changed

backend/alembic/versions/010_upgrade_user_language.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
def upgrade():
2121
# ### commands auto generated by Alembic - please adjust! ###
2222

23-
op.add_column('sys_user', sa.Column('language', sa.VARCHAR(length=255), server_default=sa.text("'zh-CN'::character varying"), nullable=False))
23+
op.add_column('sys_user', sa.Column('language', sa.VARCHAR(length=255), server_default=sa.text("'zh-CN'"), nullable=False))
2424

2525
# ### end Alembic commands ###
2626

backend/alembic/versions/018_modify_chat.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def upgrade():
2323
existing_type=sa.INTEGER(),
2424
type_=sa.BigInteger(),
2525
existing_nullable=False,
26-
autoincrement=True,
27-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1, maxvalue=2147483647, cycle=False, cache=1))
26+
autoincrement=True
27+
)
2828
op.alter_column('chat', 'datasource',
2929
existing_type=sa.INTEGER(),
3030
type_=sa.BigInteger(),
@@ -33,8 +33,7 @@ def upgrade():
3333
existing_type=sa.INTEGER(),
3434
type_=sa.BigInteger(),
3535
existing_nullable=False,
36-
autoincrement=True,
37-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1, maxvalue=2147483647, cycle=False, cache=1))
36+
autoincrement=True)
3837
op.alter_column('chat_record', 'chat_id',
3938
existing_type=sa.INTEGER(),
4039
type_=sa.BigInteger(),
@@ -68,8 +67,7 @@ def downgrade():
6867
existing_type=sa.BigInteger(),
6968
type_=sa.INTEGER(),
7069
existing_nullable=False,
71-
autoincrement=True,
72-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1, maxvalue=2147483647, cycle=False, cache=1))
70+
autoincrement=True)
7371
op.alter_column('chat', 'datasource',
7472
existing_type=sa.BigInteger(),
7573
type_=sa.INTEGER(),
@@ -78,6 +76,5 @@ def downgrade():
7876
existing_type=sa.BigInteger(),
7977
type_=sa.INTEGER(),
8078
existing_nullable=False,
81-
autoincrement=True,
82-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1, maxvalue=2147483647, cycle=False, cache=1))
79+
autoincrement=True)
8380
# ### end Alembic commands ###

backend/alembic/versions/030_permission_oid.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ def upgrade():
2424
existing_type=sa.INTEGER(),
2525
type_=sa.BigInteger(),
2626
existing_nullable=False,
27-
autoincrement=True,
28-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1,
29-
maxvalue=2147483647, cycle=False, cache=1))
27+
autoincrement=True)
3028
# ### end Alembic commands ###
3129

3230

@@ -36,8 +34,6 @@ def downgrade():
3634
existing_type=sa.BigInteger(),
3735
type_=sa.INTEGER(),
3836
existing_nullable=False,
39-
autoincrement=True,
40-
existing_server_default=sa.Identity(always=True, start=1, increment=1, minvalue=1,
41-
maxvalue=2147483647, cycle=False, cache=1))
37+
autoincrement=True)
4238
op.drop_column('ds_rules', 'oid')
4339
# ### end Alembic commands ###

backend/apps/system/api/user.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from collections import defaultdict
12
from typing import Optional
23
from fastapi import APIRouter, HTTPException, Query
3-
from sqlmodel import func, or_, select, delete as sqlmodel_delete
4+
from sqlmodel import SQLModel, func, or_, select, delete as sqlmodel_delete
45
from apps.system.crud.user import check_account_exists, check_email_exists, check_email_format, check_pwd_format, get_db_user, single_delete, user_ws_options
56
from apps.system.models.system_model import UserWsModel, WorkspaceModel
67
from apps.system.models.user import UserModel
@@ -11,6 +12,7 @@
1112
from common.core.schemas import PaginatedResponse, PaginationParams
1213
from common.core.security import default_md5_pwd, md5pwd, verify_md5pwd
1314
from common.core.sqlbot_cache import clear_cache
15+
1416
router = APIRouter(tags=["user"], prefix="/user")
1517

1618
@router.get("/info")
@@ -30,43 +32,15 @@ async def pager(
3032
pagination = PaginationParams(page=pageNum, size=pageSize)
3133
paginator = Paginator(session)
3234
filters = {}
33-
34-
stmt = (
35-
select(
36-
UserModel,
37-
func.coalesce(
38-
func.array_remove(
39-
func.array_agg(UserWsModel.oid),
40-
None
41-
),
42-
[]
43-
).label("oid_list")
44-
#func.coalesce(func.string_agg(WorkspaceModel.name, ','), '').label("space_name")
45-
)
46-
.join(UserWsModel, UserModel.id == UserWsModel.uid, isouter=True)
47-
#.join(WorkspaceModel, UserWsModel.oid == WorkspaceModel.id, isouter=True)
48-
.where(UserModel.id != 1)
49-
.group_by(UserModel.id)
50-
.order_by(UserModel.create_time)
51-
)
52-
if status is not None:
53-
stmt = stmt.where(UserModel.status == status)
5435

36+
origin_stmt = select(UserModel.id).join(UserWsModel, UserModel.id == UserWsModel.uid).where(UserModel.id != 1).distinct()
5537
if oidlist:
56-
user_filter = (
57-
select(UserModel.id)
58-
.join(UserWsModel, UserModel.id == UserWsModel.uid)
59-
.where(UserWsModel.oid.in_(oidlist))
60-
.distinct()
61-
)
62-
stmt = stmt.where(UserModel.id.in_(user_filter))
63-
64-
""" if origins is not None:
65-
stmt = stmt.where(UserModel.origin == origins) """
66-
38+
origin_stmt = origin_stmt.where(UserWsModel.oid.in_(oidlist))
39+
if status is not None:
40+
origin_stmt = origin_stmt.where(UserModel.status == status)
6741
if keyword:
6842
keyword_pattern = f"%{keyword}%"
69-
stmt = stmt.where(
43+
origin_stmt = origin_stmt.where(
7044
or_(
7145
UserModel.account.ilike(keyword_pattern),
7246
UserModel.name.ilike(keyword_pattern),
@@ -75,21 +49,47 @@ async def pager(
7549
)
7650

7751
user_page = await paginator.get_paginated_response(
78-
stmt=stmt,
52+
stmt=origin_stmt,
7953
pagination=pagination,
8054
**filters)
81-
82-
""" for item in user_page.items:
83-
space_name: str = item['space_name']
84-
if space_name and 'i18n_default_workspace' in space_name:
85-
parts = list(map(
86-
lambda x: trans(x) if x == "i18n_default_workspace" else x,
87-
space_name.split(',')
88-
))
89-
output_str = ','.join(parts)
90-
item['space_name'] = output_str """
55+
uid_list = [item.get('id') for item in user_page.items]
56+
if not uid_list:
57+
return user_page
58+
stmt = (
59+
select(UserModel, UserWsModel.oid.label('ws_oid'))
60+
.join(UserWsModel, UserModel.id == UserWsModel.uid, isouter=True)
61+
.where(UserModel.id.in_(uid_list))
62+
.order_by(UserModel.create_time)
63+
)
64+
user_workspaces = session.exec(stmt).all()
65+
merged = defaultdict(list)
66+
extra_attrs = {}
67+
68+
for (user, ws_oid) in user_workspaces:
69+
item = {}
70+
item.update(user.model_dump())
71+
user_id = item['id']
72+
merged[user_id].append(ws_oid)
73+
if user_id not in extra_attrs:
74+
extra_attrs[user_id] = {k: v for k, v in item.items() if k != "ws_oid"}
75+
76+
# 组合结果
77+
result = [
78+
{**extra_attrs[user_id], "oid_list": oid_list}
79+
for user_id, oid_list in merged.items()
80+
]
81+
user_page.items = result
9182
return user_page
9283

84+
def format_user_dict(row) -> dict:
85+
result_dict = {}
86+
for item, key in zip(row, row._fields):
87+
if isinstance(item, SQLModel):
88+
result_dict.update(item.model_dump())
89+
else:
90+
result_dict[key] = item
91+
92+
return result_dict
9393
@router.get("/ws")
9494
async def ws_options(session: SessionDep, current_user: CurrentUser, trans: Trans) -> list[UserWs]:
9595
return await user_ws_options(session, current_user.id, trans)

backend/common/core/config.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ def all_cors_origins(self) -> list[str]:
4444
self.FRONTEND_HOST
4545
]
4646

47-
POSTGRES_SERVER: str
47+
POSTGRES_SERVER: str = 'localhost'
4848
POSTGRES_PORT: int = 5432
49-
POSTGRES_USER: str
50-
POSTGRES_PASSWORD: str = ""
51-
POSTGRES_DB: str = ""
49+
POSTGRES_USER: str = 'root'
50+
POSTGRES_PASSWORD: str = "123456"
51+
POSTGRES_DB: str = "sqlbot"
52+
SQLBOT_DB_URL: str = ''
53+
#SQLBOT_DB_URL: str = 'mysql+pymysql://root:Password123%40mysql@127.0.0.1:3306/sqlbot'
5254

5355
TOKEN_KEY: str = "X-SQLBOT-TOKEN"
5456
DEFAULT_PWD: str = "SQLBot@123456"
@@ -64,7 +66,9 @@ def all_cors_origins(self) -> list[str]:
6466

6567
@computed_field # type: ignore[prop-decorator]
6668
@property
67-
def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
69+
def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
70+
if self.SQLBOT_DB_URL:
71+
return self.SQLBOT_DB_URL
6872
return MultiHostUrl.build(
6973
scheme="postgresql+psycopg",
7074
username=self.POSTGRES_USER,

backend/common/core/pagination.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ def __init__(self, session: Session):
1212
self.session = session
1313
def _process_result_row(self, row: Row) -> Dict[str, Any]:
1414
result_dict = {}
15+
if isinstance(row, int):
16+
return {'id': row}
1517
for item, key in zip(row, row._fields):
1618
if isinstance(item, SQLModel):
1719
result_dict.update(item.model_dump())

0 commit comments

Comments
 (0)