More stringent UUID types for user input / avoid 500 errors (#1309)

Fixes #1297 
Ensures proper typing for UUIDs in FastAPI input models, to avoid
explicit conversions, which may throw errors.
This avoids possible 500 errors (due to ValueError exceptions) when
converting UUIDs from user input.
Instead, clients will receive 422 validation errors from FastAPI.

The remaining UUID conversions are in operator / profile handling, where
UUIDs are retrieved from previously set fields; the remaining user-input
conversions, in user auth and collection list, are wrapped in exception handling.

For `profileid`, update the FastAPI models to support a union of UUID, null,
and EmptyStr (a new empty-string-only type), to differentiate removing the
profile (empty string) from not changing it at all (null) in config updates.
This commit is contained in:
Ilya Kreymer 2023-10-25 12:15:53 -07:00 committed by GitHub
parent 4b9ca44adb
commit 4591db1afe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 98 additions and 75 deletions

View File

@ -144,22 +144,18 @@ def init_jwt_auth(user_manager):
oauth2_scheme = OA2BearerOrQuery(tokenUrl="/api/auth/jwt/login", auto_error=False) oauth2_scheme = OA2BearerOrQuery(tokenUrl="/api/auth/jwt/login", auto_error=False)
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
credentials_exception = HTTPException(
status_code=401,
detail="invalid_credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try: try:
payload = decode_jwt(token, AUTH_ALLOW_AUD) payload = decode_jwt(token, AUTH_ALLOW_AUD)
uid: Optional[str] = payload.get("sub") or payload.get("user_id") uid: Optional[str] = payload.get("sub") or payload.get("user_id")
if uid is None: user = await user_manager.get_by_id(uuid.UUID(uid))
raise credentials_exception assert user
except Exception: return user
raise credentials_exception except:
user = await user_manager.get_by_id(uuid.UUID(uid)) raise HTTPException(
if user is None: status_code=401,
raise credentials_exception detail="invalid_credentials",
return user headers={"WWW-Authenticate": "Bearer"},
)
current_active_user = get_current_user current_active_user = get_current_user

View File

@ -23,6 +23,7 @@ from .models import (
CrawlConfig, CrawlConfig,
CrawlConfigOut, CrawlConfigOut,
CrawlConfigIdNameOut, CrawlConfigIdNameOut,
EmptyStr,
UpdateCrawlConfig, UpdateCrawlConfig,
Organization, Organization,
User, User,
@ -115,14 +116,15 @@ class CrawlConfigOps:
"""sanitize string for use in wacz filename""" """sanitize string for use in wacz filename"""
return self._file_rx.sub("-", string.lower()) return self._file_rx.sub("-", string.lower())
async def _lookup_profile(self, profileid, org): async def _lookup_profile(
self, profileid: Union[uuid.UUID, EmptyStr, None], org: Organization
) -> tuple[Optional[uuid.UUID], Optional[str]]:
if profileid is None: if profileid is None:
return None, None return None, None
if profileid == "": if isinstance(profileid, EmptyStr) or profileid == "":
return None, "" return None, ""
profileid = uuid.UUID(profileid)
profile_filename = await self.profiles.get_profile_storage_path(profileid, org) profile_filename = await self.profiles.get_profile_storage_path(profileid, org)
if not profile_filename: if not profile_filename:
raise HTTPException(status_code=400, detail="invalid_profile_id") raise HTTPException(status_code=400, detail="invalid_profile_id")
@ -135,7 +137,7 @@ class CrawlConfigOps:
config: CrawlConfigIn, config: CrawlConfigIn,
org: Organization, org: Organization,
user: User, user: User,
): ) -> tuple[str, str, bool]:
"""Add new crawl config""" """Add new crawl config"""
data = config.dict() data = config.dict()
data["oid"] = org.id data["oid"] = org.id
@ -183,7 +185,7 @@ class CrawlConfigOps:
storage=org.storage, storage=org.storage,
run_now=run_now, run_now=run_now,
out_filename=out_filename, out_filename=out_filename,
profile_filename=profile_filename, profile_filename=profile_filename or "",
) )
if crawl_id and run_now: if crawl_id and run_now:
@ -227,7 +229,7 @@ class CrawlConfigOps:
async def update_crawl_config( async def update_crawl_config(
self, cid: uuid.UUID, org: Organization, user: User, update: UpdateCrawlConfig self, cid: uuid.UUID, org: Organization, user: User, update: UpdateCrawlConfig
): ) -> dict[str, bool]:
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
"""Update name, scale, schedule, and/or tags for an existing crawl config""" """Update name, scale, schedule, and/or tags for an existing crawl config"""
@ -336,7 +338,7 @@ class CrawlConfigOps:
"metadata_changed": metadata_changed, "metadata_changed": metadata_changed,
} }
if run_now: if run_now:
crawl_id = await self.run_now(str(cid), org, user) crawl_id = await self.run_now(cid, org, user)
ret["started"] = crawl_id ret["started"] = crawl_id
return ret return ret
@ -728,9 +730,9 @@ class CrawlConfigOps:
"workflowIds": workflow_ids, "workflowIds": workflow_ids,
} }
async def run_now(self, cid: str, org: Organization, user: User): async def run_now(self, cid: uuid.UUID, org: Organization, user: User):
"""run specified crawlconfig now""" """run specified crawlconfig now"""
crawlconfig = await self.get_crawl_config(uuid.UUID(cid), org.id) crawlconfig = await self.get_crawl_config(cid, org.id)
if not crawlconfig: if not crawlconfig:
raise HTTPException( raise HTTPException(
@ -957,29 +959,29 @@ def init_crawl_config_api(
@router.get("/{cid}/seeds", response_model=PaginatedResponse) @router.get("/{cid}/seeds", response_model=PaginatedResponse)
async def get_crawl_config_seeds( async def get_crawl_config_seeds(
cid: str, cid: uuid.UUID,
org: Organization = Depends(org_viewer_dep), org: Organization = Depends(org_viewer_dep),
pageSize: int = DEFAULT_PAGE_SIZE, pageSize: int = DEFAULT_PAGE_SIZE,
page: int = 1, page: int = 1,
): ):
seeds, total = await ops.get_seeds(uuid.UUID(cid), org.id, pageSize, page) seeds, total = await ops.get_seeds(cid, org.id, pageSize, page)
return paginated_format(seeds, total, page, pageSize) return paginated_format(seeds, total, page, pageSize)
@router.get("/{cid}", response_model=CrawlConfigOut) @router.get("/{cid}", response_model=CrawlConfigOut)
async def get_crawl_config_out( async def get_crawl_config_out(
cid: str, org: Organization = Depends(org_viewer_dep) cid: uuid.UUID, org: Organization = Depends(org_viewer_dep)
): ):
return await ops.get_crawl_config_out(uuid.UUID(cid), org) return await ops.get_crawl_config_out(cid, org)
@router.get( @router.get(
"/{cid}/revs", "/{cid}/revs",
dependencies=[Depends(org_viewer_dep)], dependencies=[Depends(org_viewer_dep)],
) )
async def get_crawl_config_revisions( async def get_crawl_config_revisions(
cid: str, pageSize: int = DEFAULT_PAGE_SIZE, page: int = 1 cid: uuid.UUID, pageSize: int = DEFAULT_PAGE_SIZE, page: int = 1
): ):
revisions, total = await ops.get_crawl_config_revs( revisions, total = await ops.get_crawl_config_revs(
uuid.UUID(cid), page_size=pageSize, page=page cid, page_size=pageSize, page=page
) )
return paginated_format(revisions, total, page, pageSize) return paginated_format(revisions, total, page, pageSize)
@ -1000,24 +1002,24 @@ def init_crawl_config_api(
@router.patch("/{cid}", dependencies=[Depends(org_crawl_dep)]) @router.patch("/{cid}", dependencies=[Depends(org_crawl_dep)])
async def update_crawl_config( async def update_crawl_config(
update: UpdateCrawlConfig, update: UpdateCrawlConfig,
cid: str, cid: uuid.UUID,
org: Organization = Depends(org_crawl_dep), org: Organization = Depends(org_crawl_dep),
user: User = Depends(user_dep), user: User = Depends(user_dep),
): ):
return await ops.update_crawl_config(uuid.UUID(cid), org, user, update) return await ops.update_crawl_config(cid, org, user, update)
@router.post("/{cid}/run") @router.post("/{cid}/run")
async def run_now( async def run_now(
cid: str, cid: uuid.UUID,
org: Organization = Depends(org_crawl_dep), org: Organization = Depends(org_crawl_dep),
user: User = Depends(user_dep), user: User = Depends(user_dep),
): ) -> dict[str, str]:
crawl_id = await ops.run_now(cid, org, user) crawl_id = await ops.run_now(cid, org, user)
return {"started": crawl_id} return {"started": crawl_id}
@router.delete("/{cid}") @router.delete("/{cid}")
async def make_inactive(cid: str, org: Organization = Depends(org_crawl_dep)): async def make_inactive(cid: uuid.UUID, org: Organization = Depends(org_crawl_dep)):
crawlconfig = await ops.get_crawl_config(uuid.UUID(cid), org.id) crawlconfig = await ops.get_crawl_config(cid, org.id)
if not crawlconfig: if not crawlconfig:
raise HTTPException( raise HTTPException(

View File

@ -141,7 +141,7 @@ class CrawlManager(K8sAPI):
has_scale_update has_scale_update
or has_config_update or has_config_update
or has_timeout_update or has_timeout_update
or profile_filename or profile_filename is not None
or has_max_crawl_size_update or has_max_crawl_size_update
): ):
await self._update_config_map( await self._update_config_map(

View File

@ -7,7 +7,16 @@ from enum import Enum, IntEnum
import os import os
from typing import Optional, List, Dict, Union, Literal, Any from typing import Optional, List, Dict, Union, Literal, Any
from pydantic import BaseModel, UUID4, conint, Field, HttpUrl, AnyHttpUrl, EmailStr from pydantic import (
BaseModel,
UUID4,
conint,
Field,
HttpUrl,
AnyHttpUrl,
EmailStr,
ConstrainedStr,
)
# from fastapi_users import models as fastapi_users_models # from fastapi_users import models as fastapi_users_models
@ -165,6 +174,14 @@ class ScopeType(str, Enum):
CUSTOM = "custom" CUSTOM = "custom"
# ============================================================================
class EmptyStr(ConstrainedStr):
"""empty string only"""
min_length = 0
max_length = 0
# ============================================================================ # ============================================================================
class Seed(BaseModel): class Seed(BaseModel):
"""Crawl seed""" """Crawl seed"""
@ -231,7 +248,7 @@ class CrawlConfigIn(BaseModel):
jobType: Optional[JobType] = JobType.CUSTOM jobType: Optional[JobType] = JobType.CUSTOM
profileid: Optional[str] profileid: Union[UUID4, EmptyStr, None]
autoAddCollections: Optional[List[UUID4]] = [] autoAddCollections: Optional[List[UUID4]] = []
tags: Optional[List[str]] = [] tags: Optional[List[str]] = []
@ -372,7 +389,7 @@ class UpdateCrawlConfig(BaseModel):
# crawl data: revision tracked # crawl data: revision tracked
schedule: Optional[str] = None schedule: Optional[str] = None
profileid: Optional[str] = None profileid: Union[UUID4, EmptyStr, None] = None
crawlTimeout: Optional[int] = None crawlTimeout: Optional[int] = None
maxCrawlSize: Optional[int] = None maxCrawlSize: Optional[int] = None
scale: Optional[conint(ge=1, le=MAX_CRAWL_SCALE)] = None # type: ignore scale: Optional[conint(ge=1, le=MAX_CRAWL_SCALE)] = None # type: ignore

View File

@ -431,12 +431,10 @@ def init_orgs_api(app, mdb, user_manager, invites, user_dep):
ops = OrgOps(mdb, invites) ops = OrgOps(mdb, invites)
async def org_dep(oid: str, user: User = Depends(user_dep)): async def org_dep(oid: uuid.UUID, user: User = Depends(user_dep)):
org = await ops.get_org_for_user_by_id(uuid.UUID(oid), user) org = await ops.get_org_for_user_by_id(oid, user)
if not org: if not org:
raise HTTPException( raise HTTPException(status_code=404, detail="org_not_found")
status_code=404, detail=f"Organization '{oid}' not found"
)
if not org.is_viewer(user): if not org.is_viewer(user):
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
@ -466,8 +464,8 @@ def init_orgs_api(app, mdb, user_manager, invites, user_dep):
return org return org
async def org_public(oid: str): async def org_public(oid: uuid.UUID):
org = await ops.get_org_by_id(uuid.UUID(oid)) org = await ops.get_org_by_id(oid)
if not org: if not org:
raise HTTPException(status_code=404, detail="org_not_found") raise HTTPException(status_code=404, detail="org_not_found")
@ -651,7 +649,7 @@ def init_orgs_api(app, mdb, user_manager, invites, user_dep):
@router.post("/remove", tags=["invites"]) @router.post("/remove", tags=["invites"])
async def remove_user_from_org( async def remove_user_from_org(
remove: RemoveFromOrg, org: Organization = Depends(org_owner_dep) remove: RemoveFromOrg, org: Organization = Depends(org_owner_dep)
): ) -> dict[str, bool]:
other_user = await user_manager.get_by_email(remove.email) other_user = await user_manager.get_by_email(remove.email)
if org.is_owner(other_user): if org.is_owner(other_user):

View File

@ -161,7 +161,8 @@ class ProfileOps:
await self.crawl_manager.delete_profile_browser(browser_commit.browserid) await self.crawl_manager.delete_profile_browser(browser_commit.browserid)
file_size = resource["bytes"] # backwards compatibility
file_size = resource.get("size") or resource.get("bytes")
profile_file = ProfileFile( profile_file = ProfileFile(
hash=resource["hash"], hash=resource["hash"],

View File

@ -8,7 +8,7 @@ from urllib.parse import unquote
import asyncio import asyncio
from io import BufferedReader from io import BufferedReader
from typing import Optional, List from typing import Optional, List, Any
from fastapi import Depends, UploadFile, File from fastapi import Depends, UploadFile, File
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import UUID4 from pydantic import UUID4
@ -63,12 +63,12 @@ class UploadOps(BaseCrawlOps):
filename: str, filename: str,
name: Optional[str], name: Optional[str],
description: Optional[str], description: Optional[str],
collections: Optional[List[UUID4]], collections: Optional[List[str]],
tags: Optional[List[str]], tags: Optional[List[str]],
org: Organization, org: Organization,
user: User, user: User,
replaceId: Optional[str], replaceId: Optional[str],
): ) -> dict[str, Any]:
"""Upload streaming file, length unknown""" """Upload streaming file, length unknown"""
if await self.orgs.storage_quota_reached(org.id): if await self.orgs.storage_quota_reached(org.id):
raise HTTPException(status_code=403, detail="storage_quota_reached") raise HTTPException(status_code=403, detail="storage_quota_reached")
@ -122,11 +122,11 @@ class UploadOps(BaseCrawlOps):
uploads: List[UploadFile], uploads: List[UploadFile],
name: Optional[str], name: Optional[str],
description: Optional[str], description: Optional[str],
collections: Optional[List[UUID4]], collections: Optional[List[str]],
tags: Optional[List[str]], tags: Optional[List[str]],
org: Organization, org: Organization,
user: User, user: User,
): ) -> dict[str, Any]:
"""handle uploading content to uploads subdir + request subdir""" """handle uploading content to uploads subdir + request subdir"""
if await self.orgs.storage_quota_reached(org.id): if await self.orgs.storage_quota_reached(org.id):
raise HTTPException(status_code=403, detail="storage_quota_reached") raise HTTPException(status_code=403, detail="storage_quota_reached")
@ -145,22 +145,31 @@ class UploadOps(BaseCrawlOps):
files.append(file_reader.file_prep.get_crawl_file()) files.append(file_reader.file_prep.get_crawl_file())
return await self._create_upload( return await self._create_upload(
files, name, description, collections, tags, id_, org, user files, name, description, collections, tags, str(id_), org, user
) )
async def _create_upload( async def _create_upload(
self, files, name, description, collections, tags, id_, org, user self,
): files: List[UploadFile],
name: Optional[str],
description: Optional[str],
collections: Optional[List[str]],
tags: Optional[List[str]],
crawl_id: str,
org: Organization,
user: User,
) -> dict[str, Any]:
now = dt_now() now = dt_now()
# ts_now = now.strftime("%Y%m%d%H%M%S") file_size = sum(file_.size or 0 for file_ in files)
# crawl_id = f"upload-{ts_now}-{str(id_)[:12]}"
crawl_id = str(id_)
file_size = sum(file_.size for file_ in files) collection_uuids: List[uuid.UUID] = []
if collections:
collection_uuids = [] try:
for coll in collections: for coll in collections:
collection_uuids.append(uuid.UUID(coll)) collection_uuids.append(uuid.UUID(coll))
# pylint: disable=raise-missing-from
except:
raise HTTPException(status_code=400, detail="invalid_collection_id")
uploaded = UploadedCrawl( uploaded = UploadedCrawl(
id=crawl_id, id=crawl_id,
@ -299,7 +308,7 @@ def init_uploads_api(
tags: Optional[str] = "", tags: Optional[str] = "",
org: Organization = Depends(org_crawl_dep), org: Organization = Depends(org_crawl_dep),
user: User = Depends(user_dep), user: User = Depends(user_dep),
): ) -> dict[str, Any]:
name = unquote(name) name = unquote(name)
description = unquote(description) description = unquote(description)
colls_list = [] colls_list = []
@ -325,7 +334,7 @@ def init_uploads_api(
replaceId: Optional[str] = "", replaceId: Optional[str] = "",
org: Organization = Depends(org_crawl_dep), org: Organization = Depends(org_crawl_dep),
user: User = Depends(user_dep), user: User = Depends(user_dep),
): ) -> dict[str, Any]:
name = unquote(name) name = unquote(name)
description = unquote(description) description = unquote(description)
colls_list = [] colls_list = []

View File

@ -6,9 +6,9 @@ import os
import uuid import uuid
import asyncio import asyncio
from typing import Optional from typing import Optional, List
from pydantic import UUID4, EmailStr from pydantic import EmailStr
from fastapi import ( from fastapi import (
Request, Request,
@ -178,11 +178,11 @@ class UserManager:
return None return None
async def get_user_names_by_ids(self, user_ids): async def get_user_names_by_ids(self, user_ids: List[str]) -> dict[str, str]:
"""return list of user names for given ids""" """return list of user names for given ids"""
user_ids = [UUID4(id_) for id_ in user_ids] user_uuid_ids = [uuid.UUID(id_) for id_ in user_ids]
cursor = self.users.find( cursor = self.users.find(
{"id": {"$in": user_ids}}, projection=["id", "name", "email"] {"id": {"$in": user_uuid_ids}}, projection=["id", "name", "email"]
) )
return await cursor.to_list(length=1000) return await cursor.to_list(length=1000)
@ -369,7 +369,7 @@ class UserManager:
return user return user
async def get_by_id(self, _id: UUID4) -> Optional[User]: async def get_by_id(self, _id: uuid.UUID) -> Optional[User]:
"""get user by unique id""" """get user by unique id"""
user = await self.users.find_one({"id": _id}) user = await self.users.find_one({"id": _id})
@ -409,7 +409,7 @@ class UserManager:
raise exc raise exc
try: try:
user_uuid = UUID4(user_id) user_uuid = uuid.UUID(user_id)
except ValueError: except ValueError:
raise exc raise exc
@ -456,7 +456,7 @@ class UserManager:
user_id = data["user_id"] user_id = data["user_id"]
try: try:
user_uuid = UUID4(user_id) user_uuid = uuid.UUID(user_id)
except ValueError: except ValueError:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@ -699,8 +699,8 @@ def init_users_router(current_active_user, user_manager) -> APIRouter:
return await user_manager.format_invite(invite) return await user_manager.format_invite(invite)
@users_router.get("/invite/{token}", tags=["invites"]) @users_router.get("/invite/{token}", tags=["invites"])
async def get_invite_info(token: str, email: str): async def get_invite_info(token: uuid.UUID, email: str):
invite = await user_manager.invites.get_valid_invite(uuid.UUID(token), email) invite = await user_manager.invites.get_valid_invite(token, email)
return await user_manager.format_invite(invite) return await user_manager.format_invite(invite)
# pylint: disable=invalid-name # pylint: disable=invalid-name