402 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			402 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| FastAPI user handling (via fastapi-users)
 | |
| """
 | |
| 
 | |
| import os
 | |
| import uuid
 | |
| import asyncio
 | |
| 
 | |
| from typing import Dict, Optional
 | |
| 
 | |
| from pydantic import EmailStr, UUID4
 | |
| import passlib.pwd
 | |
| 
 | |
| from fastapi import Request, Response, HTTPException, Depends, WebSocket
 | |
| from fastapi.security import OAuth2PasswordBearer
 | |
| 
 | |
| from pymongo.errors import DuplicateKeyError
 | |
| 
 | |
| from fastapi_users import FastAPIUsers, models, BaseUserManager
 | |
| from fastapi_users.manager import UserAlreadyExists
 | |
| from fastapi_users.authentication import (
 | |
|     AuthenticationBackend,
 | |
|     BearerTransport,
 | |
|     JWTStrategy,
 | |
| )
 | |
| from fastapi_users.db import MongoDBUserDatabase
 | |
| 
 | |
| from .invites import InvitePending, InviteRequest
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
# Secret used to sign JWTs and the password-reset / verification tokens.
# NOTE: when PASSWORD_SECRET is unset, a random value is generated per process,
# so tokens signed by one process will not verify in another (or after a restart).
PASSWORD_SECRET = os.getenv("PASSWORD_SECRET", uuid.uuid4().hex)

# JWT lifetime: configured in minutes via the environment, stored in seconds.
JWT_TOKEN_LIFETIME = 60 * int(os.getenv("JWT_TOKEN_LIFETIME_MINUTES", "60"))
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
class User(models.BaseUser):
    """
    Public user model handed to FastAPIUsers: the base user fields plus a
    display name.
    """

    # display name; UserManager.create fills it with the email when empty
    name: Optional[str] = ""
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
# Custom model instead of models.BaseUserCreate, because BaseUserCreate also
# carries the is_* fields (presumably so API clients cannot set them).
class UserCreateIn(models.CreateUpdateDictModel):
    """
    Registration payload accepted from API clients (the create model given
    to FastAPIUsers).
    """

    email: EmailStr
    password: str

    # display name; UserManager.create falls back to the email when empty
    name: Optional[str] = ""

    # invite token; needed to register unless open registration is enabled
    inviteToken: Optional[UUID4] = None

    # whether to create a new archive for the user, and its optional name
    newArchive: bool
    newArchiveName: Optional[str] = ""
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
class UserCreate(models.BaseUserCreate):
    """
    Internal user-creation model. Unlike UserCreateIn it inherits the is_*
    fields (used e.g. by create_super_user to create a verified superuser).
    """

    # display name; UserManager.create falls back to the email when empty
    name: Optional[str] = ""

    # invite token; needed to register unless open registration is enabled
    inviteToken: Optional[UUID4] = None

    # whether to create a new archive for the user, and its optional name
    newArchive: bool
    newArchiveName: Optional[str] = ""
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
class UserUpdate(User, models.CreateUpdateDictModel):
    """
    User update payload; password and email may be omitted.
    """

    password: Optional[str] = None
    email: Optional[EmailStr] = None
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
class UserDB(User, models.BaseUserDB):
    """
    User as persisted in the database.
    """

    # pending invites for this user, keyed by the invite token string
    # (looked up by token in the /users/me/invite/{token} route)
    invites: Dict[str, InvitePending] = {}
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
# pylint: disable=too-few-public-methods
class UserDBOps(MongoDBUserDatabase):
    """Thin wrapper over the fastapi-users MongoDB adapter for the users collection"""
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
class UserManager(BaseUserManager[UserCreate, UserDB]):
    """Browsertrix UserManager

    Extends the fastapi-users manager with invite-gated registration, optional
    archive creation on sign-up, and email notifications for password resets
    and account verification.
    """

    user_db_model = UserDB
    reset_password_token_secret = PASSWORD_SECRET
    verification_token_secret = PASSWORD_SECRET

    def __init__(self, user_db, email, invites):
        """
        :param user_db: users database adapter, passed to BaseUserManager
        :param email: email sender (forgot-password and verification mails)
        :param invites: invite operations (used to validate invite tokens)
        """
        super().__init__(user_db)
        self.email = email
        self.invites = invites
        # injected later via set_archive_ops()
        self.archive_ops = None

        # open self-registration only when REGISTRATION_ENABLED == "1"
        self.registration_enabled = os.environ.get("REGISTRATION_ENABLED") == "1"

        # strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes
        self._background_tasks = set()

    def set_archive_ops(self, ops):
        """set archive ops"""
        self.archive_ops = ops

    def _run_in_background(self, coro):
        """schedule coro as a task and keep it referenced until it completes"""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def create(
        self, user: UserCreate, safe: bool = False, request: Optional[Request] = None
    ):
        """override user creation to check if invite token is present

        ``user`` may be a UserCreate (internal callers such as
        create_super_user) or a UserCreateIn (the API registration model,
        presumably what the fastapi-users register route passes in), which has
        no is_verified / is_superuser fields.

        :raises HTTPException: 400 if an invite token is required but missing,
            or if the given invite token is not valid for this email
        """
        user.name = user.name or user.email

        # UserCreateIn lacks the is_* fields: read them defensively instead of
        # raising AttributeError (which surfaced as a 500 rather than a 400)
        is_verified = getattr(user, "is_verified", False)
        is_superuser = getattr(user, "is_superuser", False)

        # if open registration not enabled, can only register with an invite
        if (
            not self.registration_enabled
            and not user.inviteToken
            and not is_verified
            and not is_superuser
        ):
            raise HTTPException(status_code=400, detail="Invite Token Required")

        if user.inviteToken and not await self.invites.get_valid_invite(
            user.inviteToken, user.email
        ):
            raise HTTPException(status_code=400, detail="Invalid Invite Code")

        created_user = await super().create(user, safe, request)
        await self.on_after_register_custom(created_user, user, request)
        return created_user

    async def get_user_names_by_ids(self, user_ids):
        """return list of user names for given ids

        Each returned document holds only the "id" and "name" fields (plus
        whatever the driver always includes). At most 1000 documents are
        returned.
        """
        user_ids = [UUID4(id_) for id_ in user_ids]
        cursor = self.user_db.collection.find(
            {"id": {"$in": user_ids}}, projection=["id", "name"]
        )
        return await cursor.to_list(length=1000)

    async def create_super_user(self):
        """Initialize a super user from env vars

        Reads SUPERUSER_EMAIL (required, otherwise nothing is done) and
        SUPERUSER_PASSWORD (a random password is generated when unset). An
        already existing user with that email is reported, not treated as an
        error.
        """
        email = os.environ.get("SUPERUSER_EMAIL")
        password = os.environ.get("SUPERUSER_PASSWORD")
        if not email:
            print("No superuser defined", flush=True)
            return

        if not password:
            # NOTE(review): the generated password is never shown anywhere;
            # the account then has to be recovered via the forgot-password flow
            password = passlib.pwd.genword()

        try:
            res = await self.create(
                UserCreate(
                    email=email,
                    password=password,
                    is_superuser=True,
                    newArchive=False,
                    is_verified=True,
                )
            )
            print(f"Super user {email} created", flush=True)
            # NOTE(review): res is the stored user model and may include the
            # password hash; consider logging only the id
            print(res, flush=True)

        except (DuplicateKeyError, UserAlreadyExists):
            print(f"User {email} already exists", flush=True)

    async def on_after_register_custom(
        self, user: UserDB, user_create: UserCreate, request: Optional[Request]
    ):
        """custom post registration callback, also receive the UserCreate object

        Optionally creates a new archive for the user. Then either accepts the
        user's invite (and marks the user verified) or, for uninvited and
        unverified users, requests email verification in the background.
        """

        print(f"User {user.id} has registered.")

        if user_create.newArchive:
            print(f"Creating new archive for {user.id}")

            archive_name = (
                user_create.newArchiveName or f"{user.name or user.email}'s Archive"
            )

            await self.archive_ops.create_new_archive_for_user(
                archive_name=archive_name,
                storage_name="default",
                user=user,
            )

        # UserCreateIn has no is_verified field, hence hasattr
        is_verified = hasattr(user_create, "is_verified") and user_create.is_verified

        if user_create.inviteToken:
            try:
                await self.archive_ops.handle_new_user_invite(
                    user_create.inviteToken, user
                )
            except HTTPException as exc:
                # a failed invite must not undo an already completed registration
                print(exc)

            if not is_verified:
                # if user has been invited, mark as verified immediately
                await self._update(user, {"is_verified": True})

        elif not is_verified:
            # send the verification email without blocking registration
            self._run_in_background(self.request_verify(user, request))

    async def on_after_forgot_password(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """callback after password forgot: email the reset token to the user"""
        # never log the reset token: anyone with log access could use it to
        # reset this user's password
        print(f"User {user.id} has forgot their password.")
        self.email.send_user_forgot_password(
            user.email, token, request and request.headers
        )

    async def on_after_request_verify(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """callback after verification request: email the token to the user"""

        self.email.send_user_validation(user.email, token, request and request.headers)

    async def format_invite(self, invite):
        """format an InvitePending to return via api, resolve name of inviter

        Adds "inviterName" and, for archive invites, "archiveName" to the
        serialized invite. Assumes the inviter still exists (get_by_email
        presumably raises otherwise — verify).
        """
        inviter = await self.get_by_email(invite.inviterEmail)
        result = invite.serialize()
        result["inviterName"] = inviter.name
        if invite.aid:
            archive = await self.archive_ops.get_archive_for_user_by_id(
                invite.aid, inviter
            )
            result["archiveName"] = archive.name

        return result
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
def init_user_manager(mdb, emailsender, invites):
    """
    Build the UserManager backed by the "users" collection of ``mdb``.

    :param mdb: database handle providing ``get_collection``
    :param emailsender: email sender handed to the UserManager
    :param invites: invite operations handed to the UserManager
    :returns: a new UserManager
    """
    users_db = UserDBOps(UserDB, mdb.get_collection("users"))
    return UserManager(users_db, emailsender, invites)
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
class OA2BearerOrQuery(OAuth2PasswordBearer):
    """
    OAuth2 bearer scheme that also accepts the token from the
    ``auth_bearer`` query parameter, for both HTTP requests and websockets.
    """

    async def __call__(
        self, request: Request = None, websocket: WebSocket = None
    ) -> Optional[str]:
        # the same dependency serves HTTP routes and websocket endpoints
        connection = request or websocket

        header_error = None
        try:
            token = await super().__call__(connection)
        # pylint: disable=broad-except
        except Exception as err:
            # remember the header failure; the query string may still succeed
            header_error = err
        else:
            if token:
                return token

        token = connection.query_params.get("auth_bearer")
        if token:
            return token

        if header_error:
            raise header_error

        raise HTTPException(status_code=404, detail="Not Found")
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
class BearerOrQueryTransport(BearerTransport):
    """Bearer transport whose scheme also reads the token from the query string"""

    scheme: OA2BearerOrQuery

    def __init__(self, tokenUrl: str):
        # BearerTransport.__init__ is not called; only ``scheme`` is set here
        # (assumes the base class needs no other instance state — re-check
        # on fastapi-users upgrades)
        # pylint: disable=super-init-not-called
        self.scheme = OA2BearerOrQuery(tokenUrl=tokenUrl, auto_error=False)
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
def init_users_api(app, user_manager):
    """init fastapi_users

    Registers the JWT auth (plus /refresh), register, reset-password, verify
    and /users routers on ``app``. Also adds the invite endpoints and schedules
    creation of the configured superuser.

    :param app: the FastAPI application
    :param user_manager: UserManager instance served to fastapi-users
    :returns: the configured FastAPIUsers instance
    """
    bearer_transport = BearerOrQueryTransport(tokenUrl="auth/jwt/login")

    def get_jwt_strategy() -> JWTStrategy:
        return JWTStrategy(secret=PASSWORD_SECRET, lifetime_seconds=JWT_TOKEN_LIFETIME)

    auth_backend = AuthenticationBackend(
        name="jwt",
        transport=bearer_transport,
        get_strategy=get_jwt_strategy,
    )

    fastapi_users = FastAPIUsers(
        lambda: user_manager,
        [auth_backend],
        User,
        UserCreateIn,
        UserUpdate,
        UserDB,
    )

    auth_router = fastapi_users.get_auth_router(auth_backend)

    current_active_user = fastapi_users.current_user(active=True)

    @auth_router.post("/refresh")
    async def refresh_jwt(response: Response, user=Depends(current_active_user)):
        # issue a fresh token for an already authenticated, active user
        return await auth_backend.login(get_jwt_strategy(), user, response)

    app.include_router(
        auth_router,
        prefix="/auth/jwt",
        tags=["auth"],
    )

    app.include_router(
        fastapi_users.get_register_router(),
        prefix="/auth",
        tags=["auth"],
    )
    app.include_router(
        fastapi_users.get_reset_password_router(),
        prefix="/auth",
        tags=["auth"],
    )
    app.include_router(
        fastapi_users.get_verify_router(),
        prefix="/auth",
        tags=["auth"],
    )

    users_router = fastapi_users.get_users_router()

    @users_router.post("/invite", tags=["invites"])
    async def invite_user(
        invite: InviteRequest,
        request: Request,
        user: User = Depends(current_active_user),
    ):
        # only superusers may invite brand-new users outside of an archive
        if not user.is_superuser:
            raise HTTPException(status_code=403, detail="Not Allowed")

        await user_manager.invites.invite_user(
            invite,
            user,
            user_manager,
            archive=None,
            allow_existing=False,
            headers=request.headers,
        )

        return {"invited": "new_user"}

    @users_router.get("/invite/{token}", tags=["invites"])
    async def get_invite_info(token: str, email: str):
        # a malformed token is a client error, not a server error
        try:
            invite_token = uuid.UUID(token)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid Invite Code") from None

        invite = await user_manager.invites.get_valid_invite(invite_token, email)
        # get_valid_invite may return a falsy value (UserManager.create checks
        # it the same way); don't pass it on to format_invite
        if not invite:
            raise HTTPException(status_code=400, detail="Invalid Invite Code")

        return await user_manager.format_invite(invite)

    @users_router.get("/me/invite/{token}", tags=["invites"])
    async def get_existing_user_invite_info(
        token: str, user: User = Depends(current_active_user)
    ):
        invite = user.invites.get(token)
        if invite is None:
            raise HTTPException(status_code=400, detail="Invalid Invite Code")

        return await user_manager.format_invite(invite)

    app.include_router(users_router, prefix="/users", tags=["users"])

    # NOTE(review): the task is not referenced anywhere; the event loop keeps
    # only weak references to tasks, consider keeping a reference
    asyncio.create_task(user_manager.create_super_user())

    return fastapi_users
 |