browsertrix/backend/btrixcloud/users.py
Ilya Kreymer 7aefe09581
startup fixes: (#793)
- don't run migrations on first init, just set to CURR_DB_VERSION
- implement 'run once lock' with mkdir/rmdir
- move register_exit_handler() to utils
- remove old run once handler
2023-04-24 18:32:52 -07:00

493 lines
15 KiB
Python

"""
FastAPI user handling (via fastapi-users)
"""
import os
import uuid
import asyncio
from typing import Dict, Optional
from pydantic import EmailStr, UUID4
import passlib.pwd
from fastapi import Request, Response, HTTPException, Depends, WebSocket
from fastapi.security import OAuth2PasswordBearer
from pymongo.errors import DuplicateKeyError
from fastapi_users import FastAPIUsers, models, BaseUserManager
from fastapi_users.manager import UserAlreadyExists
from fastapi_users.authentication import (
AuthenticationBackend,
BearerTransport,
JWTStrategy,
)
from fastapi_users.db import MongoDBUserDatabase
from .invites import InvitePending, UserRole
from .pagination import DEFAULT_PAGE_SIZE, paginated_format
# ============================================================================
# Secret used below for JWT signing and for the reset-password / verification
# tokens. When the PASSWORD_SECRET env var is absent, a fresh random hex string
# is generated at import time.
# NOTE(review): a generated secret differs per process and per restart, so
# tokens would presumably not survive either -- confirm deployments set it.
PASSWORD_SECRET = os.getenv("PASSWORD_SECRET", uuid.uuid4().hex)

# JWT lifetime in seconds; the env var is expressed in minutes (default 60).
JWT_TOKEN_LIFETIME = 60 * int(os.getenv("JWT_TOKEN_LIFETIME_MINUTES", 60))
# ============================================================================
class User(models.BaseUser):
    """
    Base User Model

    Extends the fastapi-users base user with a `name` field that
    defaults to an empty string.
    """

    name: Optional[str] = ""
# ============================================================================
# use a custom model, as models.BaseUserCreate includes the is_* fields
class UserCreateIn(models.CreateUpdateDictModel):
    """
    User Creation Model exposed to API

    Carries only the fields declared here: credentials, an optional name,
    an optional invite token and the new-org options.
    """

    email: EmailStr
    password: str

    name: Optional[str] = ""

    # Optional field: defaults to None when omitted
    inviteToken: Optional[UUID4] = None

    newOrg: bool
    newOrgName: Optional[str] = ""
# ============================================================================
class UserCreate(models.BaseUserCreate):
    """
    User Creation Model

    Adds name, invite token and new-org options on top of the
    fastapi-users base creation model.
    """

    name: Optional[str] = ""

    # Optional field: defaults to None when omitted
    inviteToken: Optional[UUID4] = None

    newOrg: bool
    newOrgName: Optional[str] = ""
# ============================================================================
class UserUpdate(User, models.CreateUpdateDictModel):
    """
    User Update Model

    Both password and email may be omitted from an update.
    """

    password: Optional[str] = None
    email: Optional[EmailStr] = None
# ============================================================================
class UserDB(User, models.BaseUserDB):
    """
    User in DB Model

    In addition to the inherited fields, holds a mapping of pending
    invites keyed by string (looked up by invite token elsewhere in
    this module).
    """

    invites: Dict[str, InvitePending] = {}
# ============================================================================
# pylint: disable=too-few-public-methods
class UserDBOps(MongoDBUserDatabase):
    """User DB Operations wrapper (adds nothing to MongoDBUserDatabase)"""
# ============================================================================
class UserManager(BaseUserManager[UserCreate, UserDB]):
    """Browsertrix UserManager

    Wraps the fastapi-users manager with invite-aware registration,
    superuser bootstrapping from env vars, and org membership handling
    after registration. `set_org_ops()` must be called before any method
    that touches `self.org_ops` is used.
    """

    user_db_model = UserDB
    reset_password_token_secret = PASSWORD_SECRET
    verification_token_secret = PASSWORD_SECRET

    def __init__(self, user_db, email, invites):
        """
        :param user_db: user database passed through to the base manager
        :param email: sender object; `send_user_forgot_password` and
            `send_user_validation` are called on it
        :param invites: invite ops; `get_valid_invite` is called on it
        """
        super().__init__(user_db)
        self.email = email
        self.invites = invites
        self.org_ops = None

        # open registration is only on when the env var is exactly "1"
        self.registration_enabled = os.environ.get("REGISTRATION_ENABLED") == "1"

        # Strong references to fire-and-forget tasks: the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._background_tasks = set()

    def set_org_ops(self, ops):
        """set org ops"""
        self.org_ops = ops

    async def create(
        self, user: UserCreate, safe: bool = False, request: Optional[Request] = None
    ):
        """override user creation to check if invite token is present

        :raises HTTPException: 400 if registration is closed and no invite
            token is given (verified users and superusers are exempt), or
            if the given invite token is not valid for the user's email
        :return: the created user
        """
        user.name = user.name or user.email

        # The API-facing creation model (UserCreateIn) does not declare the
        # is_* flags, so read them defensively -- same approach as the
        # hasattr() check in on_after_register_custom.
        is_verified = getattr(user, "is_verified", False)
        is_superuser = getattr(user, "is_superuser", False)

        # if open registration not enabled, can only register with an invite
        if (
            not self.registration_enabled
            and not user.inviteToken
            and not is_verified
            and not is_superuser
        ):
            raise HTTPException(status_code=400, detail="Invite Token Required")

        if user.inviteToken and not await self.invites.get_valid_invite(
            user.inviteToken, user.email
        ):
            raise HTTPException(status_code=400, detail="Invalid Invite Token")

        # Don't create a new org for registered users.
        user.newOrg = False

        created_user = await super().create(user, safe, request)
        await self.on_after_register_custom(created_user, user, request)
        return created_user

    async def get_user_names_by_ids(self, user_ids):
        """return list of user docs (id, name, email) for given ids

        NOTE: at most 1000 documents are returned.
        """
        user_ids = [UUID4(id_) for id_ in user_ids]
        cursor = self.user_db.collection.find(
            {"id": {"$in": user_ids}}, projection=["id", "name", "email"]
        )
        return await cursor.to_list(length=1000)

    async def get_superuser(self):
        """return current superuser, if any"""
        return await self.user_db.collection.find_one({"is_superuser": True})

    async def create_super_user(self):
        """Initialize a super user from env vars

        Reads SUPERUSER_EMAIL / SUPERUSER_PASSWORD. Does nothing when no
        email is set. If a superuser already exists, its password (and
        email, if different) is updated; otherwise a new verified
        superuser named "admin" is created.
        """
        email = os.environ.get("SUPERUSER_EMAIL")
        password = os.environ.get("SUPERUSER_PASSWORD")
        if not email:
            print("No superuser defined", flush=True)
            return

        if not password:
            password = passlib.pwd.genword()

        curr_superuser_res = await self.get_superuser()
        if curr_superuser_res:
            user = UserDB(**curr_superuser_res)
            # NOTE(review): the password is overwritten on every call, with
            # a generated one when SUPERUSER_PASSWORD is unset -- confirm
            # this is intended.
            update = {"password": password}
            if user.email != email:
                update["email"] = email

            try:
                await self._update(user, update)
                print("Superuser updated", flush=True)
            except UserAlreadyExists:
                print(f"User {email} already exists", flush=True)

            return

        try:
            res = await self.create(
                UserCreate(
                    name="admin",
                    email=email,
                    password=password,
                    is_superuser=True,
                    newOrg=False,
                    is_verified=True,
                )
            )
            print(f"Super user {email} created", flush=True)
            print(res, flush=True)
        except (DuplicateKeyError, UserAlreadyExists):
            print(f"User {email} already exists", flush=True)

    async def create_non_super_user(
        self,
        email: str,
        password: str,
        name: str = "New user",
    ):
        """create a regular user with given credentials

        Calls the base-class create() directly, so the invite and
        registration checks in this class's create() are not applied.

        :return: the created user, or None if no email was given or the
            user already exists
        """
        if not email:
            print("No user defined", flush=True)
            return None

        if not password:
            password = passlib.pwd.genword()

        try:
            user_create = UserCreate(
                name=name,
                email=email,
                password=password,
                is_superuser=False,
                newOrg=False,
                is_verified=True,
            )
            created_user = await super().create(user_create, safe=False, request=None)
            await self.on_after_register_custom(created_user, user_create, request=None)
            return created_user
        except (DuplicateKeyError, UserAlreadyExists):
            print(f"User {email} already exists", flush=True)
            return None

    async def on_after_register_custom(
        self, user: UserDB, user_create: UserCreate, request: Optional[Request]
    ):
        """custom post registration callback, also receive the UserCreate object

        Optionally creates a new org, resolves the invite (if any), marks
        invited users as verified or requests verification for the rest,
        and adds the user to the default org when applicable.
        """
        print(f"User {user.id} has registered.")
        add_to_default_org = False

        if user_create.newOrg is True:
            print(f"Creating new organization for {user.id}")

            org_name = (
                user_create.newOrgName or f"{user.name or user.email}'s Organization"
            )

            await self.org_ops.create_new_org_for_user(
                org_name=org_name,
                storage_name="default",
                user=user,
            )

        # the creation model may not declare is_verified at all
        is_verified = hasattr(user_create, "is_verified") and user_create.is_verified

        if user_create.inviteToken:
            new_user_invite = None
            try:
                new_user_invite = await self.org_ops.handle_new_user_invite(
                    user_create.inviteToken, user
                )
            except HTTPException as exc:
                # best-effort: a failed invite does not undo the registration
                print(exc)

            # an invite without an org id falls back to the default org
            if new_user_invite and not new_user_invite.oid:
                add_to_default_org = True

            if not is_verified:
                # if user has been invited, mark as verified immediately
                await self._update(user, {"is_verified": True})

        else:
            add_to_default_org = True
            if not is_verified:
                # Send the verification request in the background, keeping
                # a reference until the task completes.
                task = asyncio.create_task(self.request_verify(user, request))
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

        if add_to_default_org:
            default_org = await self.org_ops.get_default_org()
            if default_org:
                await self.org_ops.add_user_to_org(default_org, user.id)

    async def on_after_forgot_password(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """callback after password forgot

        Sends the reset token via the email sender. The token itself is
        deliberately not written to the log, since it is a credential.
        """
        print(f"User {user.id} has forgot their password.")
        self.email.send_user_forgot_password(
            user.email, token, request and request.headers
        )

    async def on_after_request_verify(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """callback after verification request"""
        self.email.send_user_validation(user.email, token, request and request.headers)

    async def format_invite(self, invite):
        """format an InvitePending to return via api, resolve name of inviter

        Adds "inviterName" to the serialized invite and, when the invite
        has an org id, "orgName" as well.
        """
        inviter = await self.get_by_email(invite.inviterEmail)
        result = invite.serialize()
        result["inviterName"] = inviter.name
        if invite.oid:
            org = await self.org_ops.get_org_for_user_by_id(invite.oid, inviter)
            result["orgName"] = org.name

        return result
# ============================================================================
def init_user_manager(mdb, emailsender, invites):
    """
    Build the UserManager on top of the "users" collection of `mdb`.

    :param mdb: database object providing get_collection()
    :param emailsender: passed to UserManager as its email sender
    :param invites: passed to UserManager as its invite ops
    :return: the new UserManager
    """
    return UserManager(
        UserDBOps(UserDB, mdb.get_collection("users")),
        emailsender,
        invites,
    )
# ============================================================================
class OA2BearerOrQuery(OAuth2PasswordBearer):
    """Bearer check that falls back to the `auth_bearer` query parameter"""

    async def __call__(
        self, request: Request = None, websocket: WebSocket = None
    ) -> Optional[str]:
        """
        Return the token found by the parent bearer check, else the value of
        the `auth_bearer` query parameter. If neither yields a token, re-raise
        the parent's error when there was one, otherwise raise a 404.
        """
        # use websocket as request if no request
        conn = request or websocket

        header_error = None
        header_token = None
        try:
            header_token = await super().__call__(conn)
        # pylint: disable=broad-except
        except Exception as super_exc:
            # remember the failure; the query parameter may still succeed
            header_error = super_exc

        if header_token:
            return header_token

        query_token = conn.query_params.get("auth_bearer")
        if query_token:
            return query_token

        if header_error:
            raise header_error

        raise HTTPException(status_code=404, detail="Not Found")
# ============================================================================
class BearerOrQueryTransport(BearerTransport):
    """Bearer transport that uses OA2BearerOrQuery as its scheme"""

    scheme: OA2BearerOrQuery

    def __init__(self, tokenUrl: str):
        """Install the scheme directly, without calling the parent __init__"""
        # pylint: disable=super-init-not-called
        self.scheme = OA2BearerOrQuery(tokenUrl=tokenUrl, auto_error=False)
# ============================================================================
# pylint: disable=too-many-locals
def init_users_api(app, user_manager):
    """init fastapi_users

    Registers the auth routes (JWT login + /refresh, register, reset
    password, verify) under /auth and the user / invite routes under
    /users on `app`.

    :param app: application the routers are included into
    :param user_manager: the UserManager instance served to fastapi-users
    :return: the configured FastAPIUsers object
    """
    # pylint: disable=invalid-name
    bearer_transport = BearerOrQueryTransport(tokenUrl="auth/jwt/login")

    def get_jwt_strategy() -> JWTStrategy:
        return JWTStrategy(secret=PASSWORD_SECRET, lifetime_seconds=JWT_TOKEN_LIFETIME)

    auth_backend = AuthenticationBackend(
        name="jwt",
        transport=bearer_transport,
        get_strategy=get_jwt_strategy,
    )

    fastapi_users = FastAPIUsers(
        lambda: user_manager,
        [auth_backend],
        User,
        UserCreateIn,
        UserUpdate,
        UserDB,
    )

    auth_router = fastapi_users.get_auth_router(auth_backend)

    current_active_user = fastapi_users.current_user(active=True)

    @auth_router.post("/refresh")
    async def refresh_jwt(response: Response, user=Depends(current_active_user)):
        # issue a fresh token for an already-authenticated active user
        return await auth_backend.login(get_jwt_strategy(), user, response)

    app.include_router(
        auth_router,
        prefix="/auth/jwt",
        tags=["auth"],
    )

    app.include_router(
        fastapi_users.get_register_router(),
        prefix="/auth",
        tags=["auth"],
    )

    app.include_router(
        fastapi_users.get_reset_password_router(),
        prefix="/auth",
        tags=["auth"],
    )

    app.include_router(
        fastapi_users.get_verify_router(),
        prefix="/auth",
        tags=["auth"],
    )

    users_router = fastapi_users.get_users_router()

    @users_router.get("/me-with-orgs", tags=["users"])
    async def me_with_org_info(user: User = Depends(current_active_user)):
        """/users/me with orgs user belongs to."""
        user_info = {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "orgs": [],
            "is_active": user.is_active,
            "is_superuser": user.is_superuser,
            "is_verified": user.is_verified,
        }
        user_orgs, _ = await user_manager.org_ops.get_orgs_for_user(
            user,
            # Set high so that we get all orgs even after reducing default page size
            page_size=1_000,
            calculate_total=False,
        )
        if user_orgs:
            user_info["orgs"] = [
                {
                    "id": org.id,
                    "name": org.name,
                    "default": org.default,
                    # superusers are reported as superadmin in every org
                    "role": UserRole.SUPERADMIN
                    if user.is_superuser
                    else org.users.get(str(user.id)),
                }
                for org in user_orgs
            ]
        print(f"user info with orgs: {user_info}", flush=True)
        return user_info

    @users_router.get("/invite/{token}", tags=["invites"])
    async def get_invite_info(token: str, email: str):
        """invite details for a not-yet-registered user, by token and email"""
        # The token is a raw path string: reject malformed values with a 400
        # rather than letting uuid.UUID() raise an unhandled ValueError.
        try:
            invite_token = uuid.UUID(token)
        except ValueError as exc:
            raise HTTPException(
                status_code=400, detail="Invalid Invite Code"
            ) from exc

        invite = await user_manager.invites.get_valid_invite(invite_token, email)
        if not invite:
            # guard against a falsy result before formatting
            raise HTTPException(status_code=400, detail="Invalid Invite Code")

        return await user_manager.format_invite(invite)

    @users_router.get("/me/invite/{token}", tags=["invites"])
    async def get_existing_user_invite_info(
        token: str, user: User = Depends(current_active_user)
    ):
        """invite details for an existing user, looked up in user.invites"""
        try:
            invite = user.invites[token]
        except (KeyError, AttributeError) as exc:
            # unknown token, or a user object that carries no invites
            raise HTTPException(
                status_code=400, detail="Invalid Invite Code"
            ) from exc

        return await user_manager.format_invite(invite)

    @users_router.get("/invites", tags=["invites"])
    async def get_pending_invites(
        user: User = Depends(current_active_user),
        pageSize: int = DEFAULT_PAGE_SIZE,
        page: int = 1,
    ):
        """paginated list of all pending invites (superuser only)"""
        if not user.is_superuser:
            raise HTTPException(status_code=403, detail="Not Allowed")

        pending_invites, total = await user_manager.invites.get_pending_invites(
            page_size=pageSize, page=page
        )
        return paginated_format(pending_invites, total, page, pageSize)

    app.include_router(users_router, prefix="/users", tags=["users"])

    return fastapi_users