1
+ import contextlib
1
2
from typing import Any , Callable , Dict , List , Type
2
3
3
4
from fastapi import Depends , HTTPException
14
15
PageSchema ,
15
16
)
16
17
from fastapi_amis_admin .amis .constants import DisplayModeEnum , LevelEnum
18
+ from fastapi_amis_admin .crud .base import SchemaUpdateT
17
19
from fastapi_amis_admin .crud .schema import BaseApiOut
18
20
from fastapi_amis_admin .utils .translation import i18n as _
19
21
from pydantic import BaseModel
20
- from sqlalchemy import insert , select , update
22
+ from sqlalchemy import select
21
23
from starlette import status
22
24
from starlette .requests import Request
23
25
from starlette .responses import Response
@@ -46,14 +48,12 @@ class UserLoginFormAdmin(FormAdmin):
46
48
page = Page (title = _ ("User Login" ))
47
49
page_path = "/login"
48
50
page_parser_mode = "html"
49
- schema : Type [BaseModel ] = None
51
+ schema : Type [SchemaUpdateT ] = None
50
52
schema_submit_out : Type [UserLoginOut ] = None
51
53
page_schema = None
52
54
page_route_kwargs = {"name" : "login" }
53
55
54
- async def handle (
55
- self , request : Request , data : BaseModel , ** kwargs # self.schema
56
- ) -> BaseApiOut [BaseModel ]: # self.schema_submit_out
56
+ async def handle (self , request : Request , data : SchemaUpdateT , ** kwargs ) -> BaseApiOut [BaseModel ]: # self.schema_submit_out
57
57
if request .user :
58
58
return BaseApiOut (code = 1 , msg = _ ("User logged in!" ), data = self .schema_submit_out .parse_obj (request .user ))
59
59
user = await request .auth .authenticate_user (username = data .username , password = data .password ) # type:ignore
@@ -79,16 +79,14 @@ async def route(response: Response, result: BaseApiOut = Depends(super().route_s
79
79
async def get_form (self , request : Request ) -> Form :
80
80
form = await super ().get_form (request )
81
81
buttons = []
82
- try :
82
+ with contextlib . suppress ( NoMatchFound ) :
83
83
buttons .append (
84
84
ActionType .Link (
85
85
actionType = "link" ,
86
- link = f"{ self .router_path } { self .router .url_path_for ('reg' )} " ,
86
+ link = f"{ self .site . router_path } { self . app .router .url_path_for ('reg' )} " ,
87
87
label = _ ("Sign up" ),
88
88
)
89
89
)
90
- except NoMatchFound :
91
- pass
92
90
buttons .append (Action (actionType = "submit" , label = _ ("Sign in" ), level = LevelEnum .primary ))
93
91
form .body .sort (key = lambda form_item : form_item .type , reverse = True )
94
92
form .update_from_kwargs (
@@ -130,27 +128,25 @@ class UserRegFormAdmin(FormAdmin):
130
128
page = Page (title = _ ("User Register" ))
131
129
page_path = "/reg"
132
130
page_parser_mode = "html"
133
- schema : Type [BaseModel ] = None
131
+ schema : Type [SchemaUpdateT ] = None
134
132
schema_submit_out : Type [UserLoginOut ] = None
135
133
page_schema = None
136
134
page_route_kwargs = {"name" : "reg" }
137
135
138
- async def handle (
139
- self , request : Request , data : BaseModel , ** kwargs # self.schema
140
- ) -> BaseApiOut [BaseModel ]: # self.schema_submit_out
136
+ async def handle (self , request : Request , data : SchemaUpdateT , ** kwargs ) -> BaseApiOut [BaseModel ]: # self.schema_submit_out
141
137
auth : Auth = request .auth
142
- user = await auth .db .scalar (select (self .user_model ).where (self .user_model .username == data .username ))
138
+ user = await auth .db .async_scalar (select (self .user_model ).where (self .user_model .username == data .username ))
143
139
if user :
144
140
return BaseApiOut (status = - 1 , msg = _ ("Username has been registered!" ), data = None )
145
- user = await auth .db .scalar (select (self .user_model ).where (self .user_model .email == data .email ))
141
+ user = await auth .db .async_scalar (select (self .user_model ).where (self .user_model .email == data .email ))
146
142
if user :
147
143
return BaseApiOut (status = - 2 , msg = _ ("Email has been registered!" ), data = None )
148
- user = self .user_model .parse_obj (data )
149
- values = user .dict (exclude = {"id" , "password" })
150
- values ["password" ] = auth .pwd_context .hash (user .password .get_secret_value ()) # 密码hash保存
151
- stmt = insert (self .user_model ).values (values )
144
+ values = data .dict (exclude = {"id" , "password" })
145
+ values ["password" ] = auth .pwd_context .hash (data .password .get_secret_value ()) # 密码hash保存
146
+ user = self .user_model .parse_obj (values )
152
147
try :
153
- user .id = await auth .db .async_execute (stmt , on_close_pre = lambda r : getattr (r , "lastrowid" , None ))
148
+ auth .db .add (user )
149
+ await auth .db .async_flush ()
154
150
except Exception as e :
155
151
raise HTTPException (
156
152
status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
@@ -211,7 +207,7 @@ class UserInfoFormAdmin(FormAdmin):
211
207
user_model : Type [BaseUser ] = User
212
208
page = Page (title = _ ("User Profile" ))
213
209
page_path = "/userinfo"
214
- schema : Type [BaseModel ] = None
210
+ schema : Type [SchemaUpdateT ] = None
215
211
schema_submit_out : Type [BaseUser ] = None
216
212
form_init = True
217
213
form = Form (mode = DisplayModeEnum .horizontal )
@@ -230,10 +226,9 @@ async def get_form(self, request: Request) -> Form:
230
226
form .body .extend (formitem .update_from_kwargs (disabled = True ) for formitem in formitems if formitem )
231
227
return form
232
228
233
- async def handle (self , request : Request , data : BaseModel , ** kwargs ) -> BaseApiOut [Any ]:
234
- stmt = update (self .user_model ).where (self .user_model .username == request .user .username ).values (data .dict ())
235
- await self .site .db .async_execute (stmt )
236
- await self .site .db .async_refresh (request .user )
229
+ async def handle (self , request : Request , data : SchemaUpdateT , ** kwargs ) -> BaseApiOut [Any ]:
230
+ for k , v in data .dict ().items ():
231
+ setattr (request .user , k , v )
237
232
return BaseApiOut (data = self .schema_submit_out .parse_obj (request .user ))
238
233
239
234
async def has_page_permission (self , request : Request ) -> bool :
0 commit comments