Skip to content

Commit b820cb0

Browse files
committed
perf: upgrade fastapi-amis-admin to v0.4.0
1 parent 72899bd commit b820cb0

File tree

14 files changed

+108
-56
lines changed

14 files changed

+108
-56
lines changed

fastapi_user_auth/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.3.0"
1+
__version__ = "0.4.0a1"
22
__url__ = "https://github.com/amisadmin/fastapi_user_auth"
33

44
import gettext

fastapi_user_auth/admin.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
from typing import Any, Callable, Dict, List, Type
23

34
from fastapi import Depends, HTTPException
@@ -14,10 +15,11 @@
1415
PageSchema,
1516
)
1617
from fastapi_amis_admin.amis.constants import DisplayModeEnum, LevelEnum
18+
from fastapi_amis_admin.crud.base import SchemaUpdateT
1719
from fastapi_amis_admin.crud.schema import BaseApiOut
1820
from fastapi_amis_admin.utils.translation import i18n as _
1921
from pydantic import BaseModel
20-
from sqlalchemy import insert, select, update
22+
from sqlalchemy import select
2123
from starlette import status
2224
from starlette.requests import Request
2325
from starlette.responses import Response
@@ -46,14 +48,12 @@ class UserLoginFormAdmin(FormAdmin):
4648
page = Page(title=_("User Login"))
4749
page_path = "/login"
4850
page_parser_mode = "html"
49-
schema: Type[BaseModel] = None
51+
schema: Type[SchemaUpdateT] = None
5052
schema_submit_out: Type[UserLoginOut] = None
5153
page_schema = None
5254
page_route_kwargs = {"name": "login"}
5355

54-
async def handle(
55-
self, request: Request, data: BaseModel, **kwargs # self.schema
56-
) -> BaseApiOut[BaseModel]: # self.schema_submit_out
56+
async def handle(self, request: Request, data: SchemaUpdateT, **kwargs) -> BaseApiOut[BaseModel]: # self.schema_submit_out
5757
if request.user:
5858
return BaseApiOut(code=1, msg=_("User logged in!"), data=self.schema_submit_out.parse_obj(request.user))
5959
user = await request.auth.authenticate_user(username=data.username, password=data.password) # type:ignore
@@ -79,16 +79,14 @@ async def route(response: Response, result: BaseApiOut = Depends(super().route_s
7979
async def get_form(self, request: Request) -> Form:
8080
form = await super().get_form(request)
8181
buttons = []
82-
try:
82+
with contextlib.suppress(NoMatchFound):
8383
buttons.append(
8484
ActionType.Link(
8585
actionType="link",
86-
link=f"{self.router_path}{self.router.url_path_for('reg')}",
86+
link=f"{self.site.router_path}{self.app.router.url_path_for('reg')}",
8787
label=_("Sign up"),
8888
)
8989
)
90-
except NoMatchFound:
91-
pass
9290
buttons.append(Action(actionType="submit", label=_("Sign in"), level=LevelEnum.primary))
9391
form.body.sort(key=lambda form_item: form_item.type, reverse=True)
9492
form.update_from_kwargs(
@@ -130,27 +128,25 @@ class UserRegFormAdmin(FormAdmin):
130128
page = Page(title=_("User Register"))
131129
page_path = "/reg"
132130
page_parser_mode = "html"
133-
schema: Type[BaseModel] = None
131+
schema: Type[SchemaUpdateT] = None
134132
schema_submit_out: Type[UserLoginOut] = None
135133
page_schema = None
136134
page_route_kwargs = {"name": "reg"}
137135

138-
async def handle(
139-
self, request: Request, data: BaseModel, **kwargs # self.schema
140-
) -> BaseApiOut[BaseModel]: # self.schema_submit_out
136+
async def handle(self, request: Request, data: SchemaUpdateT, **kwargs) -> BaseApiOut[BaseModel]: # self.schema_submit_out
141137
auth: Auth = request.auth
142-
user = await auth.db.scalar(select(self.user_model).where(self.user_model.username == data.username))
138+
user = await auth.db.async_scalar(select(self.user_model).where(self.user_model.username == data.username))
143139
if user:
144140
return BaseApiOut(status=-1, msg=_("Username has been registered!"), data=None)
145-
user = await auth.db.scalar(select(self.user_model).where(self.user_model.email == data.email))
141+
user = await auth.db.async_scalar(select(self.user_model).where(self.user_model.email == data.email))
146142
if user:
147143
return BaseApiOut(status=-2, msg=_("Email has been registered!"), data=None)
148-
user = self.user_model.parse_obj(data)
149-
values = user.dict(exclude={"id", "password"})
150-
values["password"] = auth.pwd_context.hash(user.password.get_secret_value()) # 密码hash保存
151-
stmt = insert(self.user_model).values(values)
144+
values = data.dict(exclude={"id", "password"})
145+
values["password"] = auth.pwd_context.hash(data.password.get_secret_value()) # 密码hash保存
146+
user = self.user_model.parse_obj(values)
152147
try:
153-
user.id = await auth.db.async_execute(stmt, on_close_pre=lambda r: getattr(r, "lastrowid", None))
148+
auth.db.add(user)
149+
await auth.db.async_flush()
154150
except Exception as e:
155151
raise HTTPException(
156152
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -211,7 +207,7 @@ class UserInfoFormAdmin(FormAdmin):
211207
user_model: Type[BaseUser] = User
212208
page = Page(title=_("User Profile"))
213209
page_path = "/userinfo"
214-
schema: Type[BaseModel] = None
210+
schema: Type[SchemaUpdateT] = None
215211
schema_submit_out: Type[BaseUser] = None
216212
form_init = True
217213
form = Form(mode=DisplayModeEnum.horizontal)
@@ -230,10 +226,9 @@ async def get_form(self, request: Request) -> Form:
230226
form.body.extend(formitem.update_from_kwargs(disabled=True) for formitem in formitems if formitem)
231227
return form
232228

233-
async def handle(self, request: Request, data: BaseModel, **kwargs) -> BaseApiOut[Any]:
234-
stmt = update(self.user_model).where(self.user_model.username == request.user.username).values(data.dict())
235-
await self.site.db.async_execute(stmt)
236-
await self.site.db.async_refresh(request.user)
229+
async def handle(self, request: Request, data: SchemaUpdateT, **kwargs) -> BaseApiOut[Any]:
230+
for k, v in data.dict().items():
231+
setattr(request.user, k, v)
237232
return BaseApiOut(data=self.schema_submit_out.parse_obj(request.user))
238233

239234
async def has_page_permission(self, request: Request) -> bool:

fastapi_user_auth/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(self, app: "AdminApp"):
4545
self.UserInfoFormAdmin.schema = self.UserInfoFormAdmin.schema or schema_create_by_schema(
4646
self.auth.user_model,
4747
"UserInfoForm",
48-
exclude={"id", "username", "password", "is_active", "parent_id", "point", "create_time"},
48+
exclude={"id", "username", "password", "is_active", "create_time"},
4949
)
5050
self.UserInfoFormAdmin.schema_submit_out = self.UserInfoFormAdmin.schema_submit_out or self.schema_user_info
5151
# register admin

fastapi_user_auth/auth/auth.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from fastapi_amis_admin.utils.translation import i18n as _
2626
from passlib.context import CryptContext
2727
from pydantic import BaseModel, SecretStr
28-
from sqlalchemy.ext.asyncio import AsyncSession
2928
from sqlalchemy.orm import Session
3029
from sqlalchemy_database import AsyncDatabase, Database
3130
from sqlmodel import select
@@ -52,9 +51,7 @@ def __init__(self, auth: "Auth", token_store: BaseTokenStore):
5251
def get_user_token(request: Request) -> Optional[str]:
5352
authorization: str = request.headers.get("Authorization") or request.cookies.get("Authorization")
5453
scheme, token = get_authorization_scheme_param(authorization)
55-
if not authorization or scheme.lower() != "bearer":
56-
return None
57-
return token
54+
return None if not authorization or scheme.lower() != "bearer" else token
5855

5956
async def authenticate(self, request: Request) -> Tuple["Auth", Optional[_UserModelT]]:
6057
return self.auth, await self.auth.get_current_user(request)
@@ -92,9 +89,7 @@ async def authenticate_user(self, username: str, password: Union[str, SecretStr]
9289

9390
@cached_property
9491
def get_current_user(self):
95-
async def _get_current_user(
96-
request: Request, session: Union[Session, AsyncSession, None] = Depends(self.db.session_generator)
97-
) -> Optional[_UserModelT]:
92+
async def _get_current_user(request: Request) -> Optional[_UserModelT]:
9893
if request.scope.get("auth"): # 防止重复授权
9994
return request.scope.get("user")
10095
request.scope["auth"], request.scope["user"] = self, None
@@ -103,7 +98,7 @@ async def _get_current_user(
10398
return None
10499
token_data = await self.backend.token_store.read_token(token)
105100
if token_data is not None:
106-
request.scope["user"]: _UserModelT = await self.db.async_get(self.user_model, token_data.id, session=session)
101+
request.scope["user"]: _UserModelT = await self.db.async_get(self.user_model, token_data.id)
107102
return request.user
108103

109104
return _get_current_user
@@ -117,18 +112,27 @@ def requires(
117112
redirect: str = None,
118113
response: Union[bool, Response] = None,
119114
) -> Callable: # sourcery no-metrics
115+
groups_ = (groups,) if not groups or isinstance(groups, str) else tuple(groups)
116+
roles_ = (roles,) if not roles or isinstance(roles, str) else tuple(roles)
117+
permissions_ = (permissions,) if not permissions or isinstance(permissions, str) else tuple(permissions)
118+
120119
async def has_requires(user: _UserModelT) -> bool:
121-
return user and await self.db.async_run_sync(
122-
user.has_requires, roles=roles, groups=groups, permissions=permissions, commit=False
123-
)
120+
return user and await self.db.async_run_sync(user.has_requires, roles=roles, groups=groups, permissions=permissions)
124121

125122
async def depend(
126123
request: Request,
127124
user: _UserModelT = Depends(self.get_current_user),
128125
) -> Union[bool, Response]:
129-
if isinstance(user, params.Depends):
130-
user = await self.get_current_user(request)
131-
if not await has_requires(user):
126+
user_auth = request.scope.get("__user_auth__", None)
127+
if user_auth is None:
128+
request.scope["__user_auth__"] = {}
129+
cache_key = (groups_, roles_, permissions_)
130+
if cache_key not in request.scope["__user_auth__"]: # 防止重复授权
131+
if isinstance(user, params.Depends):
132+
user = await self.get_current_user(request)
133+
result = await has_requires(user)
134+
request.scope["__user_auth__"][cache_key] = result
135+
if not request.scope["__user_auth__"][cache_key]:
132136
if response is not None:
133137
return response
134138
code, headers = status_code, {}
@@ -145,7 +149,7 @@ def decorator(func: Callable = None) -> Union[Callable, Coroutine]:
145149
return depend(func)
146150
sig = inspect.signature(func)
147151
for idx, parameter in enumerate(sig.parameters.values()): # noqa: B007
148-
if parameter.name == "request" or parameter.name == "websocket":
152+
if parameter.name in ["request", "websocket"]:
149153
type_ = parameter.name
150154
break
151155
else:
@@ -218,11 +222,10 @@ def _create_role_user_sync(self, session: Session, role_key: str = "admin") -> U
218222
)
219223
session.add(user)
220224
session.flush()
221-
222225
return user
223226

224227
async def create_role_user(self, role_key: str = "admin") -> User:
225-
return await self.db.async_run_sync(self._create_role_user_sync, role_key, on_close_pre=lambda user: User.parse_obj(user))
228+
return await self.db.async_run_sync(self._create_role_user_sync, role_key)
226229

227230

228231
class AuthRouter(RouterMixin):
@@ -249,7 +252,12 @@ def __init__(self, auth: Auth = None):
249252
response_model=BaseApiOut[self.schema_user_info],
250253
)
251254
self.router.add_api_route(
252-
"/logout", self.route_logout, methods=["GET"], description=_("Sign out"), dependencies=None, response_model=BaseApiOut
255+
"/logout",
256+
self.route_logout,
257+
methods=["GET"],
258+
description=_("Sign out"),
259+
dependencies=None,
260+
response_model=BaseApiOut,
253261
)
254262
# oauth2
255263
self.router.dependencies.append(Depends(self.OAuth2(tokenUrl=f"{self.router_path}/gettoken", auto_error=False)))

fastapi_user_auth/auth/backends/db.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import datetime, timedelta
33
from typing import Optional, Union
44

5-
from sqlalchemy import Column, String, delete, insert
5+
from sqlalchemy import Column, String, delete
66
from sqlalchemy_database import AsyncDatabase, Database
77
from sqlmodel import Field, select
88

@@ -32,18 +32,20 @@ async def read_token(self, token: str) -> Optional[_TokenDataSchemaT]:
3232
if obj is None:
3333
return None
3434
# expire
35-
if obj.create_time < datetime.utcnow() - timedelta(seconds=self.expire_seconds):
35+
if obj.create_time < datetime.now() - timedelta(seconds=self.expire_seconds):
3636
await self.destroy_token(token=token)
3737
return None
3838
return self.TokenDataSchema.parse_raw(obj.data)
3939

4040
async def write_token(self, token_data: Union[_TokenDataSchemaT, dict]) -> str:
4141
obj = self.TokenDataSchema.parse_obj(token_data) if isinstance(token_data, dict) else token_data
4242
token = secrets.token_urlsafe()
43-
stmt = insert(TokenStoreModel).values(dict(token=token, data=obj.json()))
44-
await self.db.async_execute(stmt)
43+
model = TokenStoreModel(token=token, data=obj.json())
44+
self.db.add(model)
45+
await self.db.async_flush()
4546
return token
4647

4748
async def destroy_token(self, token: str) -> None:
4849
stmt = delete(TokenStoreModel).where(TokenStoreModel.token == token)
4950
await self.db.async_execute(stmt)
51+
await self.db.async_flush()

fastapi_user_auth/auth/backends/jwt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def read_token(self, token: str) -> Optional[_TokenDataSchemaT]:
2828
async def write_token(self, token_data: Union[_TokenDataSchemaT, dict]) -> str:
2929
obj = self.TokenDataSchema.parse_obj(token_data) if isinstance(token_data, dict) else token_data
3030
data = obj.dict()
31-
expire = datetime.utcnow() + timedelta(seconds=self.expire_seconds)
31+
expire = datetime.now() + timedelta(seconds=self.expire_seconds)
3232
data.update({"exp": expire})
3333
return jwt.encode(data, self.secret_key, algorithm=self.algorithm)
3434

fastapi_user_auth/auth/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ def has_requires(
241241
Returns:
242242
检测成功返回`True`
243243
"""
244+
if not groups and not roles and not permissions:
245+
return True
244246
stmt = select(1)
245247
if groups:
246248
groups_list = [groups] if isinstance(groups, str) else list(groups)

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,11 @@ classifiers = [
3737
]
3838
dependencies = [
3939
"pydantic>=1.9",
40-
"fastapi-amis-admin>=0.2.4,<0.4.0",
40+
"fastapi-amis-admin>=0.4.0a1,<0.5.0",
4141
"email-validator",
4242
"passlib>=1.7.4",
4343
"bcrypt>=4.0.0",
4444
"sqlmodelx>=0.0.2",
45-
"sqlalchemy-database>=0.0.10",
4645
]
4746

4847
[project.urls]

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from sqlalchemy_database import AsyncDatabase, Database
23

34
# sqlite
@@ -17,3 +18,16 @@
1718

1819
# SQL Server
1920
# sync_db = Database.create('mssql+pyodbc://scott:tiger@mydsn')
21+
22+
23+
@pytest.fixture(autouse=True)
24+
def _setup_sync_db() -> Database:
25+
yield sync_db
26+
# Free connection pool resources
27+
sync_db.close() # type: ignore
28+
29+
30+
@pytest.fixture(autouse=True)
31+
async def _setup_async_db() -> AsyncDatabase:
32+
yield async_db
33+
await async_db.async_close() # Free connection pool resources

tests/test_auth/conftest.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sqlalchemy.orm import Session
77
from sqlalchemy_database import AsyncDatabase, Database
88
from sqlmodel import SQLModel
9+
from starlette.middleware.base import BaseHTTPMiddleware
910
from starlette.testclient import TestClient
1011

1112
from fastapi_user_auth.auth.auth import Auth, AuthRouter
@@ -19,6 +20,7 @@ async def db(request) -> Union[Database, AsyncDatabase]:
1920
await database.async_run_sync(SQLModel.metadata.create_all, is_session=False)
2021
yield database
2122
await database.async_run_sync(SQLModel.metadata.drop_all, is_session=False)
23+
await database.async_close()
2224

2325

2426
@pytest.fixture()
@@ -35,7 +37,7 @@ def event_loop():
3537

3638
@pytest.fixture(scope="session")
3739
async def fake_auth() -> Auth:
38-
auth = Auth(db=AsyncDatabase.create("sqlite+aiosqlite:///amisadmin.db?check_same_thread=False"))
40+
auth = Auth(db=async_db)
3941

4042
# noinspection PyTypeChecker
4143
def create_fake_users(session: Session):
@@ -80,8 +82,10 @@ def create_fake_users(session: Session):
8082

8183
await auth.db.async_run_sync(SQLModel.metadata.create_all, is_session=False)
8284
await auth.db.async_run_sync(create_fake_users)
85+
await auth.db.async_commit()
8386
yield auth
8487
await auth.db.async_run_sync(SQLModel.metadata.drop_all, is_session=False)
88+
await auth.db.async_close()
8589

8690

8791
class UserClient:
@@ -95,6 +99,7 @@ def __init__(self, auth: Auth, client: TestClient = None, user: User = None) ->
9599
@pytest.fixture(scope="session")
96100
def logins(request, fake_auth: Auth) -> UserClient:
97101
app = FastAPI()
102+
app.add_middleware(BaseHTTPMiddleware, dispatch=async_db.asgi_dispatch)
98103
# 注册auth基础路由
99104
auth_router = AuthRouter(auth=fake_auth)
100105
app.include_router(auth_router.router)
@@ -105,7 +110,6 @@ def logins(request, fake_auth: Auth) -> UserClient:
105110
"test": {"username": "test", "password": "test"},
106111
"guest": {"username": None, "password": None},
107112
}
108-
user = user_data.get(request.param) or {}
109113

110114
def get_login_client(username: str = None, password: str = None) -> UserClient:
111115
client = TestClient(app)
@@ -123,4 +127,4 @@ def get_login_client(username: str = None, password: str = None) -> UserClient:
123127
assert user.username == username
124128
return UserClient(fake_auth, client=client, user=user)
125129

126-
return get_login_client(**user)
130+
return get_login_client(**user_data.get(request.param, {}))

0 commit comments

Comments
 (0)