473 lines
15 KiB
Python
473 lines
15 KiB
Python
"""
|
|
FastAPI user handling (via fastapi-users)
|
|
"""
|
|
|
|
import os
|
|
import uuid
|
|
import asyncio
|
|
|
|
from typing import Optional, Union
|
|
|
|
from pydantic import UUID4
|
|
import passlib.pwd
|
|
|
|
from fastapi import Request, Response, HTTPException, Depends, WebSocket
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
|
|
from pymongo.errors import DuplicateKeyError
|
|
|
|
from fastapi_users import FastAPIUsers, BaseUserManager
|
|
from fastapi_users.manager import UserAlreadyExists, InvalidPasswordException
|
|
from fastapi_users.authentication import (
|
|
AuthenticationBackend,
|
|
BearerTransport,
|
|
JWTStrategy,
|
|
)
|
|
from fastapi_users.db import MongoDBUserDatabase
|
|
|
|
from .models import (
|
|
User,
|
|
UserCreateIn,
|
|
UserCreate,
|
|
UserUpdate,
|
|
UserDB,
|
|
UserRole,
|
|
PaginatedResponse,
|
|
)
|
|
from .pagination import DEFAULT_PAGE_SIZE, paginated_format
|
|
from .utils import is_bool
|
|
|
|
# ============================================================================
|
|
# Secret shared by the token-signing settings in this module. When the
# PASSWORD_SECRET env var is absent, a random hex string is generated at
# import time, so the fallback value differs on every process start.
PASSWORD_SECRET = os.getenv("PASSWORD_SECRET", uuid.uuid4().hex)

# JWT lifetime in seconds; configured in minutes via
# JWT_TOKEN_LIFETIME_MINUTES (defaults to 60 minutes).
JWT_TOKEN_LIFETIME = 60 * int(os.getenv("JWT_TOKEN_LIFETIME_MINUTES", 60))
|
|
|
|
|
|
# ============================================================================
|
|
# pylint: disable=too-few-public-methods
|
|
class UserDBOps(MongoDBUserDatabase):
    """Project-local subclass of MongoDBUserDatabase; adds no behavior of its own."""
|
|
|
|
|
|
# ============================================================================
|
|
class UserManager(BaseUserManager[UserCreate, UserDB]):
    """Browsertrix UserManager

    Extends the fastapi-users ``BaseUserManager`` with invite-gated
    registration, a password length rule, creation of users from code / env
    vars, and post-registration organization assignment.
    """

    user_db_model = UserDB
    reset_password_token_secret = PASSWORD_SECRET
    verification_token_secret = PASSWORD_SECRET

    def __init__(self, user_db, email, invites):
        """Store collaborators; ``org_ops`` is injected later via set_org_ops().

        :param user_db: user database adapter handed to the base manager.
        :param email: object used to send forgot-password / validation mails.
        :param invites: object used to look up invites.
        """
        super().__init__(user_db)
        self.email = email
        self.invites = invites
        self.org_ops = None

        self.registration_enabled = is_bool(os.environ.get("REGISTRATION_ENABLED"))

        # Strong references to fire-and-forget tasks: the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._background_tasks = set()

    def set_org_ops(self, ops):
        """set org ops"""
        self.org_ops = ops

    async def create(
        self, user: UserCreate, safe: bool = False, request: Optional[Request] = None
    ):
        """override user creation to check if invite token is present

        :raises HTTPException: 400 if registration is closed and no invite
            token was supplied, or if the supplied invite token is not valid.
        :return: the created user.
        """
        # fall back to the email as display name
        user.name = user.name or user.email

        # if open registration not enabled, can only register with an invite
        # (pre-verified users and superusers are exempt)
        if (
            not self.registration_enabled
            and not user.inviteToken
            and not user.is_verified
            and not user.is_superuser
        ):
            raise HTTPException(status_code=400, detail="Invite Token Required")

        if user.inviteToken and not await self.invites.get_valid_invite(
            user.inviteToken, user.email
        ):
            raise HTTPException(status_code=400, detail="Invalid Invite Token")

        # Don't create a new org for registered users.
        user.newOrg = False

        created_user = await super().create(user, safe, request)
        await self.on_after_register_custom(created_user, user, request)
        return created_user

    async def validate_password(
        self, password: str, user: Union[UserCreate, UserDB]
    ) -> None:
        """
        Validate a password.

        Overloaded to set password requirements: length must be between
        8 and 64 characters inclusive.

        :param password: The password to validate.
        :param user: The user associated to this password (unused here).
        :raises InvalidPasswordException: The password is invalid.
        :return: None if the password is valid.
        """
        pw_length = len(password)
        if not 8 <= pw_length <= 64:
            raise InvalidPasswordException(reason="invalid_password_length")

    async def get_user_names_by_ids(self, user_ids):
        """return user documents (id, name, email only) for given ids

        At most 1000 documents are returned.
        """
        user_ids = [UUID4(id_) for id_ in user_ids]
        cursor = self.user_db.collection.find(
            {"id": {"$in": user_ids}}, projection=["id", "name", "email"]
        )
        return await cursor.to_list(length=1000)

    async def get_user_by_id(self, user_id: uuid.UUID):
        """return user from user_id"""
        return await self.user_db.get(user_id)

    async def get_superuser(self):
        """return the raw document of a user with is_superuser set, if any"""
        return await self.user_db.collection.find_one({"is_superuser": True})

    async def create_super_user(self):
        """Initialize a super user from env vars

        Reads SUPERUSER_EMAIL and SUPERUSER_PASSWORD. Without an email
        nothing is done; without a password one is generated. If a superuser
        already exists, its password (and email, if different) is updated
        instead of creating a new user.

        :raises HTTPException: 422 if the password is rejected on creation.
        """
        email = os.environ.get("SUPERUSER_EMAIL")
        password = os.environ.get("SUPERUSER_PASSWORD")
        if not email:
            print("No superuser defined", flush=True)
            return

        if not password:
            password = passlib.pwd.genword()

        curr_superuser_res = await self.get_superuser()
        if curr_superuser_res:
            user = UserDB(**curr_superuser_res)
            update = {"password": password}
            if user.email != email:
                update["email"] = email

            try:
                await self._update(user, update)
                print("Superuser updated", flush=True)
            except UserAlreadyExists:
                print(f"User {email} already exists", flush=True)

            return

        try:
            res = await self.create(
                UserCreate(
                    name="admin",
                    email=email,
                    password=password,
                    is_superuser=True,
                    newOrg=False,
                    is_verified=True,
                )
            )
            print(f"Super user {email} created", flush=True)
            # NOTE(review): this prints the whole created user object; confirm
            # its repr does not expose sensitive fields.
            print(res, flush=True)
        except (DuplicateKeyError, UserAlreadyExists):
            print(f"User {email} already exists", flush=True)
        # pylint: disable=raise-missing-from
        except InvalidPasswordException:
            raise HTTPException(status_code=422, detail="invalid_password")

    async def create_non_super_user(
        self,
        email: str,
        password: str,
        name: str = "New user",
    ):
        """create a regular user with given credentials

        Calls the base-class ``create`` directly, so the invite checks in
        this class's ``create`` override are not applied. The user is created
        as verified. A missing password is replaced by a generated one.

        :raises HTTPException: 422 if the password is rejected.
        :return: the created user, or None if no email was given or the user
            already exists.
        """
        if not email:
            print("No user defined", flush=True)
            return

        if not password:
            password = passlib.pwd.genword()

        try:
            user_create = UserCreate(
                name=name,
                email=email,
                password=password,
                is_superuser=False,
                newOrg=False,
                is_verified=True,
            )
            created_user = await super().create(user_create, safe=False, request=None)
            await self.on_after_register_custom(
                created_user, user_create, request=None
            )
            return created_user
        # pylint: disable=raise-missing-from
        except InvalidPasswordException:
            raise HTTPException(status_code=422, detail="invalid_password")
        except (DuplicateKeyError, UserAlreadyExists):
            print(f"User {email} already exists", flush=True)

    async def on_after_register_custom(
        self, user: UserDB, user_create: UserCreate, request: Optional[Request]
    ):
        """custom post registration callback, also receive the UserCreate object

        Optionally creates a new org, processes the invite (marking invited
        users verified), schedules a verification request for uninvited
        unverified users, and adds the user to the default org when
        applicable.
        """

        print(f"User {user.id} has registered.")
        add_to_default_org = False

        if user_create.newOrg is True:
            print(f"Creating new organization for {user.id}")

            org_name = (
                user_create.newOrgName or f"{user.name or user.email}'s Organization"
            )

            await self.org_ops.create_new_org_for_user(
                org_name=org_name,
                storage_name="default",
                user=user,
            )

        is_verified = getattr(user_create, "is_verified", False)

        if user_create.inviteToken:
            new_user_invite = None
            try:
                new_user_invite = await self.org_ops.handle_new_user_invite(
                    user_create.inviteToken, user
                )
            except HTTPException as exc:
                # best-effort: a failed invite does not abort registration
                print(exc)

            # an invite without an org id places the user in the default org
            if new_user_invite and not new_user_invite.oid:
                add_to_default_org = True

            if not is_verified:
                # if user has been invited, mark as verified immediately
                await self._update(user, {"is_verified": True})

        else:
            add_to_default_org = True
            if not is_verified:
                # Run the verification request in the background, keeping a
                # reference so the task is not garbage-collected mid-flight.
                verify_task = asyncio.create_task(self.request_verify(user, request))
                self._background_tasks.add(verify_task)
                verify_task.add_done_callback(self._background_tasks.discard)

        if add_to_default_org:
            default_org = await self.org_ops.get_default_org()
            if default_org:
                await self.org_ops.add_user_to_org(default_org, user.id)

    async def on_after_forgot_password(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """callback after password forgot: mail the reset token to the user"""
        # The reset token is a credential; it is deliberately kept out of
        # the log line.
        print(f"User {user.id} has forgot their password.")
        self.email.send_user_forgot_password(
            user.email, token, request and request.headers
        )

    async def on_after_request_verify(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """callback after verification request: mail the token to the user"""

        self.email.send_user_validation(user.email, token, request and request.headers)

    async def format_invite(self, invite):
        """format an InvitePending to return via api, resolve name of inviter

        Adds ``inviterName`` and, when the invite has an org id, ``orgName``
        to the serialized invite.
        """
        inviter = await self.get_by_email(invite.inviterEmail)
        result = invite.serialize()
        result["inviterName"] = inviter.name
        if invite.oid:
            org = await self.org_ops.get_org_for_user_by_id(invite.oid, inviter)
            result["orgName"] = org.name

        return result
|
|
|
|
|
|
# ============================================================================
|
|
def init_user_manager(mdb, emailsender, invites):
    """
    Build the UserManager backed by the "users" collection of ``mdb``.

    ``emailsender`` and ``invites`` are passed through to the UserManager
    constructor unchanged.
    """
    users_db = UserDBOps(UserDB, mdb.get_collection("users"))
    return UserManager(users_db, emailsender, invites)
|
|
|
|
|
|
# ============================================================================
|
|
class OA2BearerOrQuery(OAuth2PasswordBearer):
    """Bearer scheme that falls back to the ``auth_bearer`` query parameter"""

    async def __call__(
        self, request: Request = None, websocket: WebSocket = None  # type: ignore
    ) -> Optional[str]:
        # Either a plain request or a websocket may carry the credentials.
        conn = request or websocket  # type: ignore
        header_error = None

        # First choice: whatever the parent bearer scheme extracts.
        try:
            token = await super().__call__(conn)  # type: ignore
        # pylint: disable=broad-except
        except Exception as err:
            # Remember the failure; the query parameter may still succeed.
            header_error = err
        else:
            if token:
                return token

        # Second choice: token passed in the query string.
        query_token = conn.query_params.get("auth_bearer") if conn else None
        if query_token:
            return query_token

        # Nothing found: surface the parent's error if there was one.
        if header_error:
            raise header_error

        raise HTTPException(status_code=404, detail="Not Found")
|
|
|
|
|
|
# ============================================================================
|
|
class BearerOrQueryTransport(BearerTransport):
    """Bearer transport whose scheme is an OA2BearerOrQuery instance"""

    scheme: OA2BearerOrQuery

    # pylint: disable=invalid-name
    def __init__(self, tokenUrl: str):
        # The parent initializer is intentionally skipped; only the scheme
        # attribute is set up here, with auto_error turned off.
        # pylint: disable=super-init-not-called
        self.scheme = OA2BearerOrQuery(tokenUrl=tokenUrl, auto_error=False)
|
|
|
|
|
|
# ============================================================================
|
|
# pylint: disable=too-many-locals
|
|
def init_users_api(app, user_manager):
    """init fastapi_users

    Registers on ``app``:
    - the JWT auth router (plus a ``/refresh`` endpoint) under ``/auth/jwt``
    - the register, reset-password and verify routers under ``/auth``
    - the users router, extended with ``/me-with-orgs``, ``/invite/{token}``,
      ``/me/invite/{token}`` and ``/invites``, under ``/users``

    :param app: application the routers are included into.
    :param user_manager: the UserManager instance served to fastapi-users.
    :return: the configured FastAPIUsers instance.
    """
    # pylint: disable=invalid-name

    bearer_transport = BearerOrQueryTransport(tokenUrl="auth/jwt/login")

    # A fresh JWT strategy per call, using the module-level secret/lifetime.
    def get_jwt_strategy() -> JWTStrategy:
        return JWTStrategy(secret=PASSWORD_SECRET, lifetime_seconds=JWT_TOKEN_LIFETIME)

    auth_backend = AuthenticationBackend(
        name="jwt",
        transport=bearer_transport,
        get_strategy=get_jwt_strategy,
    )

    fastapi_users = FastAPIUsers(
        lambda: user_manager,
        [auth_backend],
        User,
        UserCreateIn,
        UserUpdate,
        UserDB,
    )

    auth_router = fastapi_users.get_auth_router(auth_backend)

    current_active_user = fastapi_users.current_user(active=True)

    # Issue a new token for an already-authenticated active user.
    @auth_router.post("/refresh")
    async def refresh_jwt(response: Response, user=Depends(current_active_user)):
        return await auth_backend.login(get_jwt_strategy(), user, response)

    app.include_router(
        auth_router,
        prefix="/auth/jwt",
        tags=["auth"],
    )

    app.include_router(
        fastapi_users.get_register_router(),
        prefix="/auth",
        tags=["auth"],
    )
    app.include_router(
        fastapi_users.get_reset_password_router(),
        prefix="/auth",
        tags=["auth"],
    )
    app.include_router(
        fastapi_users.get_verify_router(),
        prefix="/auth",
        tags=["auth"],
    )

    users_router = fastapi_users.get_users_router()

    @users_router.get("/me-with-orgs", tags=["users"])
    async def me_with_org_info(user: User = Depends(current_active_user)):
        """/users/me with orgs user belongs to."""
        user_info = {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "orgs": [],
            "is_active": user.is_active,
            "is_superuser": user.is_superuser,
            "is_verified": user.is_verified,
        }
        user_orgs, _ = await user_manager.org_ops.get_orgs_for_user(
            user,
            # Set high so that we get all orgs even after reducing default page size
            page_size=1_000,
            calculate_total=False,
        )
        if user_orgs:
            user_info["orgs"] = [
                {
                    "id": org.id,
                    "name": org.name,
                    "default": org.default,
                    # superusers are reported as SUPERADMIN in every org
                    "role": UserRole.SUPERADMIN
                    if user.is_superuser
                    else org.users.get(str(user.id)),
                }
                for org in user_orgs
            ]
        return user_info

    # Invite details for a not-yet-registered user, looked up by token + email.
    @users_router.get("/invite/{token}", tags=["invites"])
    async def get_invite_info(token: str, email: str):
        try:
            invite_token = uuid.UUID(token)
        except ValueError:
            # malformed token: answer 400 instead of an unhandled error
            # pylint: disable=raise-missing-from
            raise HTTPException(status_code=400, detail="Invalid Invite Code")

        invite = await user_manager.invites.get_valid_invite(invite_token, email)
        # UserManager.create treats a falsy result as an invalid invite;
        # do the same here rather than formatting a missing invite.
        if not invite:
            raise HTTPException(status_code=400, detail="Invalid Invite Code")

        return await user_manager.format_invite(invite)

    # Invite details for an existing user, read from the user's own invites.
    @users_router.get("/me/invite/{token}", tags=["invites"])
    async def get_existing_user_invite_info(
        token: str, user: UserDB = Depends(current_active_user)
    ):
        try:
            invite = user.invites[token]
        # only lookup failures mean "invalid code"; cancellation and
        # interpreter-exit signals must propagate
        except (KeyError, TypeError, AttributeError):
            # pylint: disable=raise-missing-from
            raise HTTPException(status_code=400, detail="Invalid Invite Code")

        return await user_manager.format_invite(invite)

    # Paginated list of pending invites; superusers only.
    @users_router.get("/invites", tags=["invites"], response_model=PaginatedResponse)
    async def get_pending_invites(
        user: User = Depends(current_active_user),
        pageSize: int = DEFAULT_PAGE_SIZE,
        page: int = 1,
    ):
        if not user.is_superuser:
            raise HTTPException(status_code=403, detail="Not Allowed")

        pending_invites, total = await user_manager.invites.get_pending_invites(
            page_size=pageSize, page=page
        )
        return paginated_format(pending_invites, total, page, pageSize)

    app.include_router(users_router, prefix="/users", tags=["users"])

    return fastapi_users
|