browsertrix/backend/btrixcloud/users.py
2022-10-19 21:47:34 -07:00

402 lines
12 KiB
Python

"""
FastAPI user handling (via fastapi-users)
"""
import os
import uuid
import asyncio
from typing import Dict, Optional
from pydantic import EmailStr, UUID4
import passlib.pwd
from fastapi import Request, Response, HTTPException, Depends, WebSocket
from fastapi.security import OAuth2PasswordBearer
from pymongo.errors import DuplicateKeyError
from fastapi_users import FastAPIUsers, models, BaseUserManager
from fastapi_users.manager import UserAlreadyExists
from fastapi_users.authentication import (
AuthenticationBackend,
BearerTransport,
JWTStrategy,
)
from fastapi_users.db import MongoDBUserDatabase
from .invites import InvitePending, InviteRequest
# ============================================================================
# Secret used to sign JWTs and the password-reset / verification tokens.
# An unset *or empty* PASSWORD_SECRET falls back to a random per-process value
# (tokens then do not survive a restart, nor work across separate processes);
# an empty string must never be used as a signing key.
PASSWORD_SECRET = os.environ.get("PASSWORD_SECRET") or uuid.uuid4().hex
# JWT lifetime in seconds; configured in minutes (default 60 minutes)
JWT_TOKEN_LIFETIME = int(os.environ.get("JWT_TOKEN_LIFETIME_MINUTES", 60)) * 60
# ============================================================================
class User(models.BaseUser):
    """
    Base User Model

    Extends the fastapi-users base user with an optional display ``name``
    (defaults to an empty string). Passed to ``FastAPIUsers`` as the user
    model in init_users_api().
    """

    name: Optional[str] = ""
# ============================================================================
# Use a custom model instead of models.BaseUserCreate, which includes the is_*
# fields (is_superuser, is_verified, ...). UserManager.create() skips the
# invite requirement for verified/superuser users, so clients registering
# through the API must not be able to set those flags.
class UserCreateIn(models.CreateUpdateDictModel):
    """
    User Creation Model exposed to API

    Fields:
    - ``name``: display name; UserManager.create() falls back to the email
      when it is empty.
    - ``inviteToken``: optional invite token, validated against the email in
      UserManager.create().
    - ``newArchive``: whether to create a new archive for the user after
      registration.
    - ``newArchiveName``: name for that archive; when empty, a
      "<name or email>'s Archive" name is generated.
    """

    email: EmailStr
    password: str
    name: Optional[str] = ""
    inviteToken: Optional[UUID4]
    newArchive: bool
    newArchiveName: Optional[str] = ""
# ============================================================================
class UserCreate(models.BaseUserCreate):
    """
    User Creation Model

    Internal creation model consumed by UserManager.create(). Unlike
    UserCreateIn, it inherits the is_* flags from models.BaseUserCreate;
    create_super_user() uses them to create a verified superuser.
    The extra fields mirror UserCreateIn.
    """

    name: Optional[str] = ""
    inviteToken: Optional[UUID4]
    newArchive: bool
    newArchiveName: Optional[str] = ""
# ============================================================================
class UserUpdate(User, models.CreateUpdateDictModel):
    """
    User Update Model

    Update model passed to ``FastAPIUsers``; ``password`` and ``email`` are
    optional, so either may be omitted from an update.
    """

    password: Optional[str]
    email: Optional[EmailStr]
# ============================================================================
class UserDB(User, models.BaseUserDB):
    """
    User in DB Model

    Stored user record. ``invites`` maps an invite token (as a string) to its
    ``InvitePending`` entry; the ``/users/me/invite/{token}`` endpoint looks
    invites up by that token string.
    """

    invites: Dict[str, InvitePending] = {}
# ============================================================================
# pylint: disable=too-few-public-methods
class UserDBOps(MongoDBUserDatabase):
    """User DB Operations wrapper

    Thin subclass of fastapi-users' ``MongoDBUserDatabase`` that currently adds
    no behavior of its own; init_user_manager() instantiates it with the
    ``users`` collection.
    """
# ============================================================================
class UserManager(BaseUserManager[UserCreate, UserDB]):
    """Browsertrix UserManager

    Customizes the fastapi-users ``BaseUserManager``:
    - registration requires an invite token unless open registration is
      enabled (REGISTRATION_ENABLED=1),
    - new users may get an archive created for them,
    - a superuser can be bootstrapped from env vars,
    - forgot-password / verification mails are sent through ``self.email``.
    """

    user_db_model = UserDB
    reset_password_token_secret = PASSWORD_SECRET
    verification_token_secret = PASSWORD_SECRET

    def __init__(self, user_db, email, invites):
        """
        :param user_db: user database adapter (a ``UserDBOps``)
        :param email: email sender for forgot-password / verification mails
        :param invites: invite operations used to validate invite tokens
        """
        super().__init__(user_db)
        self.email = email
        self.invites = invites
        # set later via set_archive_ops()
        self.archive_ops = None
        self.registration_enabled = os.environ.get("REGISTRATION_ENABLED") == "1"
        # asyncio keeps only weak references to running tasks; hold strong
        # references here so fire-and-forget tasks are not garbage collected
        # before they finish
        self._background_tasks = set()

    def set_archive_ops(self, ops):
        """set archive ops (used for archive creation and invite handling)"""
        self.archive_ops = ops

    def _run_in_background(self, coro):
        """Schedule ``coro`` as a task and keep a reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def create(
        self, user: UserCreate, safe: bool = False, request: Optional[Request] = None
    ):
        """override user creation to check if invite token is present

        Without open registration, a user can only be created with an invite
        token, unless it is already verified or a superuser (as for the
        bootstrapped superuser). A supplied invite token must be valid for the
        user's email.

        :raises HTTPException: 400 if an invite token is required but missing,
            or if the supplied invite token is invalid
        """
        # default the display name to the email address
        user.name = user.name or user.email

        # if open registration not enabled, can only register with an invite
        if (
            not self.registration_enabled
            and not user.inviteToken
            and not user.is_verified
            and not user.is_superuser
        ):
            raise HTTPException(status_code=400, detail="Invite Token Required")

        if user.inviteToken and not await self.invites.get_valid_invite(
            user.inviteToken, user.email
        ):
            raise HTTPException(status_code=400, detail="Invalid Invite Token")

        created_user = await super().create(user, safe, request)
        await self.on_after_register_custom(created_user, user, request)
        return created_user

    async def get_user_names_by_ids(self, user_ids):
        """return list of user names for given ids

        :param user_ids: iterable of user ids as UUID strings
        :returns: list of matching user documents projected to ``id`` / ``name``
        """
        user_ids = [UUID4(id_) for id_ in user_ids]
        if not user_ids:
            return []

        cursor = self.user_db.collection.find(
            {"id": {"$in": user_ids}}, projection=["id", "name"]
        )
        # the query is already bounded by the id list; don't silently
        # truncate the result at an arbitrary fixed cap
        return await cursor.to_list(length=None)

    async def create_super_user(self):
        """Initialize a super user from env vars

        Reads SUPERUSER_EMAIL and SUPERUSER_PASSWORD. Does nothing if no email
        is set; generates a random password if none is set. If the user
        already exists, it is left untouched.
        """
        email = os.environ.get("SUPERUSER_EMAIL")
        password = os.environ.get("SUPERUSER_PASSWORD")
        if not email:
            print("No superuser defined", flush=True)
            return

        if not password:
            password = passlib.pwd.genword()

        try:
            res = await self.create(
                UserCreate(
                    email=email,
                    password=password,
                    is_superuser=True,
                    newArchive=False,
                    is_verified=True,
                )
            )
            print(f"Super user {email} created", flush=True)
            # log only the id: the full UserDB record (fastapi-users'
            # BaseUserDB) carries the hashed password
            print(f"Super user id: {res.id}", flush=True)
        except (DuplicateKeyError, UserAlreadyExists):
            print(f"User {email} already exists", flush=True)

    async def on_after_register_custom(
        self, user: UserDB, user_create: UserCreate, request: Optional[Request]
    ):
        """custom post registration callback, also receive the UserCreate object

        Creates a new archive if requested and handles a supplied invite token.
        Invited users are marked verified immediately; other unverified users
        get a verification request scheduled in the background.
        """
        print(f"User {user.id} has registered.")

        if user_create.newArchive:
            print(f"Creating new archive for {user.id}")

            archive_name = (
                user_create.newArchiveName or f"{user.name or user.email}'s Archive"
            )

            await self.archive_ops.create_new_archive_for_user(
                archive_name=archive_name,
                storage_name="default",
                user=user,
            )

        is_verified = hasattr(user_create, "is_verified") and user_create.is_verified

        if user_create.inviteToken:
            try:
                await self.archive_ops.handle_new_user_invite(
                    user_create.inviteToken, user
                )
            except HTTPException as exc:
                # best-effort: invite processing errors don't fail registration
                print(exc)

            if not is_verified:
                # if user has been invited, mark as verified immediately
                await self._update(user, {"is_verified": True})

        elif not is_verified:
            self._run_in_background(self.request_verify(user, request))

    async def on_after_forgot_password(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """callback after password forgot: email the reset token to the user"""
        # never log the reset token itself: anyone with log access could use
        # it to reset this user's password
        print(f"User {user.id} has forgot their password.")
        self.email.send_user_forgot_password(
            user.email, token, request and request.headers
        )

    async def on_after_request_verify(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """callback after verification request: email the verification token"""
        self.email.send_user_validation(user.email, token, request and request.headers)

    async def format_invite(self, invite):
        """format an InvitePending to return via api, resolve name of inviter

        Also resolves the archive name when the invite refers to an archive.
        """
        inviter = await self.get_by_email(invite.inviterEmail)
        result = invite.serialize()
        result["inviterName"] = inviter.name
        if invite.aid:
            archive = await self.archive_ops.get_archive_for_user_by_id(
                invite.aid, inviter
            )
            result["archiveName"] = archive.name

        return result
# ============================================================================
def init_user_manager(mdb, emailsender, invites):
    """
    Build the UserManager backed by the ``users`` collection of ``mdb``.

    :param mdb: database handle providing ``get_collection``
    :param emailsender: email sender handed to the UserManager
    :param invites: invite operations handed to the UserManager
    :returns: a new ``UserManager`` (routes are registered separately by
        init_users_api())
    """
    users_db = UserDBOps(UserDB, mdb.get_collection("users"))
    return UserManager(users_db, emailsender, invites)
# ============================================================================
class OA2BearerOrQuery(OAuth2PasswordBearer):
    """OAuth2 bearer scheme that also accepts an ``auth_bearer`` query param.

    The token is taken from the standard bearer check first, then from the
    ``auth_bearer`` query parameter (useful for websockets). If neither yields
    a token, any error from the bearer check is re-raised; otherwise a 404 is
    raised.
    """

    async def __call__(
        self, request: Request = None, websocket: WebSocket = None
    ) -> Optional[str]:
        # websocket connections arrive without a Request; use the socket instead
        conn = request or websocket
        header_error = None

        try:
            token = await super().__call__(conn)
        # pylint: disable=broad-except
        except Exception as err:
            header_error = err
        else:
            if token:
                return token

        token = conn.query_params.get("auth_bearer")
        if token:
            return token

        if header_error:
            raise header_error

        raise HTTPException(status_code=404, detail="Not Found")
# ============================================================================
class BearerOrQueryTransport(BearerTransport):
    """Bearer transport whose scheme also reads the ``auth_bearer`` query param."""

    scheme: OA2BearerOrQuery

    def __init__(self, tokenUrl: str):
        # BearerTransport.__init__ is deliberately not called; this class
        # installs its own scheme (non-auto-erroring OA2BearerOrQuery) instead
        # pylint: disable=super-init-not-called
        query_aware_scheme = OA2BearerOrQuery(tokenUrl, auto_error=False)
        self.scheme = query_aware_scheme
# ============================================================================
def init_users_api(app, user_manager):
    """init fastapi_users

    Register the authentication and user routes on ``app``:
    - ``/auth/jwt`` login/logout routes plus a ``/auth/jwt/refresh`` endpoint,
    - register, reset-password and verify routes under ``/auth``,
    - ``/users`` routes, extended with invite endpoints.
    Also schedules superuser creation as a background task.

    :param app: FastAPI application to register routers on
    :param user_manager: the ``UserManager`` shared by all routes
    :returns: the configured ``FastAPIUsers`` instance
    """
    bearer_transport = BearerOrQueryTransport(tokenUrl="auth/jwt/login")

    def get_jwt_strategy() -> JWTStrategy:
        return JWTStrategy(secret=PASSWORD_SECRET, lifetime_seconds=JWT_TOKEN_LIFETIME)

    auth_backend = AuthenticationBackend(
        name="jwt",
        transport=bearer_transport,
        get_strategy=get_jwt_strategy,
    )

    fastapi_users = FastAPIUsers(
        lambda: user_manager,
        [auth_backend],
        User,
        UserCreateIn,
        UserUpdate,
        UserDB,
    )

    auth_router = fastapi_users.get_auth_router(auth_backend)

    current_active_user = fastapi_users.current_user(active=True)

    @auth_router.post("/refresh")
    async def refresh_jwt(response: Response, user=Depends(current_active_user)):
        # issue a fresh token for an already-authenticated active user
        return await auth_backend.login(get_jwt_strategy(), user, response)

    app.include_router(
        auth_router,
        prefix="/auth/jwt",
        tags=["auth"],
    )

    app.include_router(
        fastapi_users.get_register_router(),
        prefix="/auth",
        tags=["auth"],
    )

    app.include_router(
        fastapi_users.get_reset_password_router(),
        prefix="/auth",
        tags=["auth"],
    )

    app.include_router(
        fastapi_users.get_verify_router(),
        prefix="/auth",
        tags=["auth"],
    )

    users_router = fastapi_users.get_users_router()

    @users_router.post("/invite", tags=["invites"])
    async def invite_user(
        invite: InviteRequest,
        request: Request,
        user: User = Depends(current_active_user),
    ):
        # only superusers may invite brand-new users outside of an archive
        if not user.is_superuser:
            raise HTTPException(status_code=403, detail="Not Allowed")

        await user_manager.invites.invite_user(
            invite,
            user,
            user_manager,
            archive=None,
            allow_existing=False,
            headers=request.headers,
        )

        return {"invited": "new_user"}

    @users_router.get("/invite/{token}", tags=["invites"])
    async def get_invite_info(token: str, email: str):
        # a malformed token is a client error (400), not a server error (500)
        try:
            invite_token = uuid.UUID(token)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid Invite Code") from None

        invite = await user_manager.invites.get_valid_invite(invite_token, email)
        if not invite:
            raise HTTPException(status_code=400, detail="Invalid Invite Code")

        return await user_manager.format_invite(invite)

    @users_router.get("/me/invite/{token}", tags=["invites"])
    async def get_existing_user_invite_info(
        token: str, user: User = Depends(current_active_user)
    ):
        invite = user.invites.get(token)
        if not invite:
            raise HTTPException(status_code=400, detail="Invalid Invite Code")

        return await user_manager.format_invite(invite)

    app.include_router(users_router, prefix="/users", tags=["users"])

    # keep a strong reference to the task: asyncio only holds weak references
    # to tasks, so an unreferenced one may be garbage collected mid-execution
    user_manager.super_user_init_task = asyncio.create_task(
        user_manager.create_super_user()
    )

    return fastapi_users