More stringent UUID types for user input / avoid 500 errors (#1309)

Fixes #1297 
Ensures proper UUID typing in FastAPI input models, avoiding explicit
`uuid.UUID()` conversions that may throw errors.
This prevents possible 500 errors (from uncaught ValueError exceptions)
when converting UUIDs from user input; invalid IDs now result in 422
validation errors from FastAPI instead.
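
To illustrate the effect (a minimal sketch, not code from this PR; the route and names are hypothetical): with the parameter declared as `uuid.UUID`, FastAPI rejects malformed input during validation, so the handler never performs a raw conversion that could raise:

```python
import uuid

from fastapi import FastAPI

app = FastAPI()


@app.get("/items/{item_id}")
async def get_item(item_id: uuid.UUID):
    # FastAPI validates the path parameter before calling the handler:
    # GET /items/not-a-uuid returns a 422 response and never reaches
    # this body, instead of raising ValueError (-> 500) inside it.
    return {"id": str(item_id)}
```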

The remaining UUID conversions are in operator / profile handling, where
UUIDs are retrieved from previously set fields rather than user input. The
remaining user-input conversions, in user auth and the collection list, are
wrapped in exception handlers.
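
The wrapped conversions follow roughly this pattern (a sketch with an illustrative helper name, not an actual helper from this PR):

```python
import uuid

from fastapi import HTTPException


def parse_uuid_or_400(value: str) -> uuid.UUID:
    """Convert user input explicitly, mapping bad input to a 4xx error."""
    try:
        return uuid.UUID(value)
    # pylint: disable=raise-missing-from
    except ValueError:
        raise HTTPException(status_code=400, detail="invalid_id")
```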

For `profileid`, the FastAPI models are updated to accept a union of UUID,
null, and `EmptyStr` (a new empty-string-only type) to differentiate
removing the profile (empty string) from leaving it unchanged (null) in
config updates.
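
A sketch of how the three states are distinguished on the backend (model and function names here are illustrative; `EmptyStr` matches the type added in models.py below, and this assumes pydantic v1, which provides `ConstrainedStr`):

```python
from typing import Union

from pydantic import BaseModel, UUID4, ConstrainedStr


class EmptyStr(ConstrainedStr):
    """empty string only"""

    min_length = 0
    max_length = 0


class UpdateConfig(BaseModel):
    # None = field omitted (leave profile as-is)
    # ""   = validated as EmptyStr (remove the profile)
    # UUID = switch to that profile
    profileid: Union[UUID4, EmptyStr, None] = None


def profile_action(update: UpdateConfig) -> str:
    if update.profileid is None:
        return "unchanged"
    # pydantic v1 may hand back a plain str here, hence the double
    # check, mirroring the _lookup_profile change below
    if isinstance(update.profileid, EmptyStr) or update.profileid == "":
        return "removed"
    return f"set to {update.profileid}"
```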
Committed by Ilya Kreymer on 2023-10-25 12:15:53 -07:00 (via GitHub)
commit 4591db1afe, parent 4b9ca44adb
8 changed files with 98 additions and 75 deletions

View File

@@ -144,22 +144,18 @@ def init_jwt_auth(user_manager):
     oauth2_scheme = OA2BearerOrQuery(tokenUrl="/api/auth/jwt/login", auto_error=False)
 
     async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
-        credentials_exception = HTTPException(
-            status_code=401,
-            detail="invalid_credentials",
-            headers={"WWW-Authenticate": "Bearer"},
-        )
         try:
             payload = decode_jwt(token, AUTH_ALLOW_AUD)
             uid: Optional[str] = payload.get("sub") or payload.get("user_id")
-            if uid is None:
-                raise credentials_exception
-        except Exception:
-            raise credentials_exception
-
-        user = await user_manager.get_by_id(uuid.UUID(uid))
-        if user is None:
-            raise credentials_exception
-        return user
+            user = await user_manager.get_by_id(uuid.UUID(uid))
+            assert user
+            return user
+        except:
+            raise HTTPException(
+                status_code=401,
+                detail="invalid_credentials",
+                headers={"WWW-Authenticate": "Bearer"},
+            )
 
     current_active_user = get_current_user

View File

@@ -23,6 +23,7 @@ from .models import (
     CrawlConfig,
     CrawlConfigOut,
     CrawlConfigIdNameOut,
+    EmptyStr,
     UpdateCrawlConfig,
     Organization,
     User,
@@ -115,14 +116,15 @@ class CrawlConfigOps:
         """sanitize string for use in wacz filename"""
         return self._file_rx.sub("-", string.lower())
 
-    async def _lookup_profile(self, profileid, org):
+    async def _lookup_profile(
+        self, profileid: Union[uuid.UUID, EmptyStr, None], org: Organization
+    ) -> tuple[Optional[uuid.UUID], Optional[str]]:
         if profileid is None:
             return None, None
 
-        if profileid == "":
+        if isinstance(profileid, EmptyStr) or profileid == "":
             return None, ""
 
-        profileid = uuid.UUID(profileid)
         profile_filename = await self.profiles.get_profile_storage_path(profileid, org)
         if not profile_filename:
             raise HTTPException(status_code=400, detail="invalid_profile_id")
@@ -135,7 +137,7 @@ class CrawlConfigOps:
         config: CrawlConfigIn,
         org: Organization,
         user: User,
-    ):
+    ) -> tuple[str, str, bool]:
         """Add new crawl config"""
         data = config.dict()
         data["oid"] = org.id
@@ -183,7 +185,7 @@ class CrawlConfigOps:
             storage=org.storage,
             run_now=run_now,
             out_filename=out_filename,
-            profile_filename=profile_filename,
+            profile_filename=profile_filename or "",
         )
 
         if crawl_id and run_now:
@@ -227,7 +229,7 @@ class CrawlConfigOps:
     async def update_crawl_config(
         self, cid: uuid.UUID, org: Organization, user: User, update: UpdateCrawlConfig
-    ):
+    ) -> dict[str, bool]:
         # pylint: disable=too-many-locals
         """Update name, scale, schedule, and/or tags for an existing crawl config"""
@@ -336,7 +338,7 @@ class CrawlConfigOps:
             "metadata_changed": metadata_changed,
         }
 
         if run_now:
-            crawl_id = await self.run_now(str(cid), org, user)
+            crawl_id = await self.run_now(cid, org, user)
             ret["started"] = crawl_id
 
         return ret
@@ -728,9 +730,9 @@ class CrawlConfigOps:
             "workflowIds": workflow_ids,
         }
 
-    async def run_now(self, cid: str, org: Organization, user: User):
+    async def run_now(self, cid: uuid.UUID, org: Organization, user: User):
         """run specified crawlconfig now"""
-        crawlconfig = await self.get_crawl_config(uuid.UUID(cid), org.id)
+        crawlconfig = await self.get_crawl_config(cid, org.id)
 
         if not crawlconfig:
             raise HTTPException(
@@ -957,29 +959,29 @@ def init_crawl_config_api(
     @router.get("/{cid}/seeds", response_model=PaginatedResponse)
     async def get_crawl_config_seeds(
-        cid: str,
+        cid: uuid.UUID,
         org: Organization = Depends(org_viewer_dep),
         pageSize: int = DEFAULT_PAGE_SIZE,
         page: int = 1,
     ):
-        seeds, total = await ops.get_seeds(uuid.UUID(cid), org.id, pageSize, page)
+        seeds, total = await ops.get_seeds(cid, org.id, pageSize, page)
         return paginated_format(seeds, total, page, pageSize)
 
     @router.get("/{cid}", response_model=CrawlConfigOut)
     async def get_crawl_config_out(
-        cid: str, org: Organization = Depends(org_viewer_dep)
+        cid: uuid.UUID, org: Organization = Depends(org_viewer_dep)
     ):
-        return await ops.get_crawl_config_out(uuid.UUID(cid), org)
+        return await ops.get_crawl_config_out(cid, org)
 
     @router.get(
         "/{cid}/revs",
         dependencies=[Depends(org_viewer_dep)],
     )
     async def get_crawl_config_revisions(
-        cid: str, pageSize: int = DEFAULT_PAGE_SIZE, page: int = 1
+        cid: uuid.UUID, pageSize: int = DEFAULT_PAGE_SIZE, page: int = 1
    ):
         revisions, total = await ops.get_crawl_config_revs(
-            uuid.UUID(cid), page_size=pageSize, page=page
+            cid, page_size=pageSize, page=page
         )
         return paginated_format(revisions, total, page, pageSize)
@@ -1000,24 +1002,24 @@ def init_crawl_config_api(
     @router.patch("/{cid}", dependencies=[Depends(org_crawl_dep)])
     async def update_crawl_config(
         update: UpdateCrawlConfig,
-        cid: str,
+        cid: uuid.UUID,
         org: Organization = Depends(org_crawl_dep),
         user: User = Depends(user_dep),
     ):
-        return await ops.update_crawl_config(uuid.UUID(cid), org, user, update)
+        return await ops.update_crawl_config(cid, org, user, update)
 
     @router.post("/{cid}/run")
     async def run_now(
-        cid: str,
+        cid: uuid.UUID,
         org: Organization = Depends(org_crawl_dep),
         user: User = Depends(user_dep),
-    ):
+    ) -> dict[str, str]:
         crawl_id = await ops.run_now(cid, org, user)
         return {"started": crawl_id}
 
     @router.delete("/{cid}")
-    async def make_inactive(cid: str, org: Organization = Depends(org_crawl_dep)):
-        crawlconfig = await ops.get_crawl_config(uuid.UUID(cid), org.id)
+    async def make_inactive(cid: uuid.UUID, org: Organization = Depends(org_crawl_dep)):
+        crawlconfig = await ops.get_crawl_config(cid, org.id)
 
         if not crawlconfig:
             raise HTTPException(
View File

@@ -141,7 +141,7 @@ class CrawlManager(K8sAPI):
             has_scale_update
             or has_config_update
             or has_timeout_update
-            or profile_filename
+            or profile_filename is not None
             or has_max_crawl_size_update
         ):
             await self._update_config_map(

View File

@@ -7,7 +7,16 @@ from enum import Enum, IntEnum
 import os
 
 from typing import Optional, List, Dict, Union, Literal, Any
 
-from pydantic import BaseModel, UUID4, conint, Field, HttpUrl, AnyHttpUrl, EmailStr
+from pydantic import (
+    BaseModel,
+    UUID4,
+    conint,
+    Field,
+    HttpUrl,
+    AnyHttpUrl,
+    EmailStr,
+    ConstrainedStr,
+)
 
 # from fastapi_users import models as fastapi_users_models
@@ -165,6 +174,14 @@ class ScopeType(str, Enum):
     CUSTOM = "custom"
 
 
 # ============================================================================
+class EmptyStr(ConstrainedStr):
+    """empty string only"""
+
+    min_length = 0
+    max_length = 0
+
+
+# ============================================================================
 class Seed(BaseModel):
     """Crawl seed"""
@@ -231,7 +248,7 @@ class CrawlConfigIn(BaseModel):
     jobType: Optional[JobType] = JobType.CUSTOM
 
-    profileid: Optional[str]
+    profileid: Union[UUID4, EmptyStr, None]
 
     autoAddCollections: Optional[List[UUID4]] = []
     tags: Optional[List[str]] = []
@@ -372,7 +389,7 @@ class UpdateCrawlConfig(BaseModel):
     # crawl data: revision tracked
     schedule: Optional[str] = None
-    profileid: Optional[str] = None
+    profileid: Union[UUID4, EmptyStr, None] = None
     crawlTimeout: Optional[int] = None
     maxCrawlSize: Optional[int] = None
     scale: Optional[conint(ge=1, le=MAX_CRAWL_SCALE)] = None  # type: ignore

View File

@@ -431,12 +431,10 @@ def init_orgs_api(app, mdb, user_manager, invites, user_dep):
     ops = OrgOps(mdb, invites)
 
-    async def org_dep(oid: str, user: User = Depends(user_dep)):
-        org = await ops.get_org_for_user_by_id(uuid.UUID(oid), user)
+    async def org_dep(oid: uuid.UUID, user: User = Depends(user_dep)):
+        org = await ops.get_org_for_user_by_id(oid, user)
         if not org:
-            raise HTTPException(
-                status_code=404, detail=f"Organization '{oid}' not found"
-            )
+            raise HTTPException(status_code=404, detail="org_not_found")
         if not org.is_viewer(user):
             raise HTTPException(
                 status_code=403,
@@ -466,8 +464,8 @@ def init_orgs_api(app, mdb, user_manager, invites, user_dep):
         return org
 
-    async def org_public(oid: str):
-        org = await ops.get_org_by_id(uuid.UUID(oid))
+    async def org_public(oid: uuid.UUID):
+        org = await ops.get_org_by_id(oid)
 
         if not org:
             raise HTTPException(status_code=404, detail="org_not_found")
@@ -651,7 +649,7 @@ def init_orgs_api(app, mdb, user_manager, invites, user_dep):
     @router.post("/remove", tags=["invites"])
     async def remove_user_from_org(
         remove: RemoveFromOrg, org: Organization = Depends(org_owner_dep)
-    ):
+    ) -> dict[str, bool]:
         other_user = await user_manager.get_by_email(remove.email)
 
         if org.is_owner(other_user):

View File

@@ -161,7 +161,8 @@ class ProfileOps:
         await self.crawl_manager.delete_profile_browser(browser_commit.browserid)
 
-        file_size = resource["bytes"]
+        # backwards compatibility
+        file_size = resource.get("size") or resource.get("bytes")
 
         profile_file = ProfileFile(
             hash=resource["hash"],

View File

@@ -8,7 +8,7 @@ from urllib.parse import unquote
 import asyncio
 from io import BufferedReader
-from typing import Optional, List
+from typing import Optional, List, Any
 from fastapi import Depends, UploadFile, File
 from fastapi import HTTPException
 from pydantic import UUID4
@@ -63,12 +63,12 @@ class UploadOps(BaseCrawlOps):
         filename: str,
         name: Optional[str],
         description: Optional[str],
-        collections: Optional[List[UUID4]],
+        collections: Optional[List[str]],
         tags: Optional[List[str]],
         org: Organization,
         user: User,
         replaceId: Optional[str],
-    ):
+    ) -> dict[str, Any]:
         """Upload streaming file, length unknown"""
         if await self.orgs.storage_quota_reached(org.id):
             raise HTTPException(status_code=403, detail="storage_quota_reached")
@@ -122,11 +122,11 @@ class UploadOps(BaseCrawlOps):
         uploads: List[UploadFile],
         name: Optional[str],
         description: Optional[str],
-        collections: Optional[List[UUID4]],
+        collections: Optional[List[str]],
         tags: Optional[List[str]],
         org: Organization,
         user: User,
-    ):
+    ) -> dict[str, Any]:
         """handle uploading content to uploads subdir + request subdir"""
         if await self.orgs.storage_quota_reached(org.id):
             raise HTTPException(status_code=403, detail="storage_quota_reached")
@@ -145,22 +145,31 @@ class UploadOps(BaseCrawlOps):
             files.append(file_reader.file_prep.get_crawl_file())
 
         return await self._create_upload(
-            files, name, description, collections, tags, id_, org, user
+            files, name, description, collections, tags, str(id_), org, user
         )
 
     async def _create_upload(
-        self, files, name, description, collections, tags, id_, org, user
-    ):
+        self,
+        files: List[UploadFile],
+        name: Optional[str],
+        description: Optional[str],
+        collections: Optional[List[str]],
+        tags: Optional[List[str]],
+        crawl_id: str,
+        org: Organization,
+        user: User,
+    ) -> dict[str, Any]:
         now = dt_now()
-        # ts_now = now.strftime("%Y%m%d%H%M%S")
-        # crawl_id = f"upload-{ts_now}-{str(id_)[:12]}"
-        crawl_id = str(id_)
-        file_size = sum(file_.size or 0 for file_ in files)
+        file_size = sum(file_.size for file_ in files)
 
-        collection_uuids = []
-        for coll in collections:
-            collection_uuids.append(uuid.UUID(coll))
+        collection_uuids: List[uuid.UUID] = []
+        if collections:
+            try:
+                for coll in collections:
+                    collection_uuids.append(uuid.UUID(coll))
+            # pylint: disable=raise-missing-from
+            except:
+                raise HTTPException(status_code=400, detail="invalid_collection_id")
 
         uploaded = UploadedCrawl(
             id=crawl_id,
@@ -299,7 +308,7 @@ def init_uploads_api(
         tags: Optional[str] = "",
         org: Organization = Depends(org_crawl_dep),
         user: User = Depends(user_dep),
-    ):
+    ) -> dict[str, Any]:
         name = unquote(name)
         description = unquote(description)
         colls_list = []
@@ -325,7 +334,7 @@ def init_uploads_api(
         replaceId: Optional[str] = "",
         org: Organization = Depends(org_crawl_dep),
         user: User = Depends(user_dep),
-    ):
+    ) -> dict[str, Any]:
         name = unquote(name)
         description = unquote(description)
         colls_list = []

View File

@@ -6,9 +6,9 @@ import os
 import uuid
 import asyncio
 
-from typing import Optional
+from typing import Optional, List
 
-from pydantic import UUID4, EmailStr
+from pydantic import EmailStr
 
 from fastapi import (
     Request,
@@ -178,11 +178,11 @@ class UserManager:
         return None
 
-    async def get_user_names_by_ids(self, user_ids):
+    async def get_user_names_by_ids(self, user_ids: List[str]) -> dict[str, str]:
         """return list of user names for given ids"""
-        user_ids = [UUID4(id_) for id_ in user_ids]
+        user_uuid_ids = [uuid.UUID(id_) for id_ in user_ids]
         cursor = self.users.find(
-            {"id": {"$in": user_ids}}, projection=["id", "name", "email"]
+            {"id": {"$in": user_uuid_ids}}, projection=["id", "name", "email"]
         )
         return await cursor.to_list(length=1000)
@@ -369,7 +369,7 @@ class UserManager:
         return user
 
-    async def get_by_id(self, _id: UUID4) -> Optional[User]:
+    async def get_by_id(self, _id: uuid.UUID) -> Optional[User]:
         """get user by unique id"""
         user = await self.users.find_one({"id": _id})
@@ -409,7 +409,7 @@ class UserManager:
             raise exc
 
         try:
-            user_uuid = UUID4(user_id)
+            user_uuid = uuid.UUID(user_id)
         except ValueError:
             raise exc
@@ -456,7 +456,7 @@ class UserManager:
         user_id = data["user_id"]
 
         try:
-            user_uuid = UUID4(user_id)
+            user_uuid = uuid.UUID(user_id)
         except ValueError:
             raise HTTPException(
                 status_code=400,
@@ -699,8 +699,8 @@ def init_users_router(current_active_user, user_manager) -> APIRouter:
         return await user_manager.format_invite(invite)
 
     @users_router.get("/invite/{token}", tags=["invites"])
-    async def get_invite_info(token: str, email: str):
-        invite = await user_manager.invites.get_valid_invite(uuid.UUID(token), email)
+    async def get_invite_info(token: uuid.UUID, email: str):
+        invite = await user_manager.invites.get_valid_invite(token, email)
         return await user_manager.format_invite(invite)
 
     # pylint: disable=invalid-name