More stringent UUID types for user input / avoid 500 errors (#1309)
Fixes #1297. Ensures proper typing for UUIDs in FastAPI input models, avoiding explicit conversions that may throw errors. This prevents possible 500 errors (due to ValueError exceptions) when converting UUIDs from user input; instead, FastAPI will return 422 errors for invalid input. The remaining UUID conversions are in operator / profile handling, where UUIDs are retrieved from previously set fields; the remaining user-input conversions, in user auth and the collection list, are wrapped in try/except handlers. For `profileid`, the FastAPI models are updated to support a union of UUID, null, and EmptyStr (a new empty-string-only type), to differentiate removing the profile (empty string) from not changing it at all (null) in config updates.
This commit is contained in:
parent
4b9ca44adb
commit
4591db1afe
@ -144,22 +144,18 @@ def init_jwt_auth(user_manager):
|
||||
oauth2_scheme = OA2BearerOrQuery(tokenUrl="/api/auth/jwt/login", auto_error=False)
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=401,
|
||||
detail="invalid_credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = decode_jwt(token, AUTH_ALLOW_AUD)
|
||||
uid: Optional[str] = payload.get("sub") or payload.get("user_id")
|
||||
if uid is None:
|
||||
raise credentials_exception
|
||||
except Exception:
|
||||
raise credentials_exception
|
||||
user = await user_manager.get_by_id(uuid.UUID(uid))
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
user = await user_manager.get_by_id(uuid.UUID(uid))
|
||||
assert user
|
||||
return user
|
||||
except:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="invalid_credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
current_active_user = get_current_user
|
||||
|
||||
|
@ -23,6 +23,7 @@ from .models import (
|
||||
CrawlConfig,
|
||||
CrawlConfigOut,
|
||||
CrawlConfigIdNameOut,
|
||||
EmptyStr,
|
||||
UpdateCrawlConfig,
|
||||
Organization,
|
||||
User,
|
||||
@ -115,14 +116,15 @@ class CrawlConfigOps:
|
||||
"""sanitize string for use in wacz filename"""
|
||||
return self._file_rx.sub("-", string.lower())
|
||||
|
||||
async def _lookup_profile(self, profileid, org):
|
||||
async def _lookup_profile(
|
||||
self, profileid: Union[uuid.UUID, EmptyStr, None], org: Organization
|
||||
) -> tuple[Optional[uuid.UUID], Optional[str]]:
|
||||
if profileid is None:
|
||||
return None, None
|
||||
|
||||
if profileid == "":
|
||||
if isinstance(profileid, EmptyStr) or profileid == "":
|
||||
return None, ""
|
||||
|
||||
profileid = uuid.UUID(profileid)
|
||||
profile_filename = await self.profiles.get_profile_storage_path(profileid, org)
|
||||
if not profile_filename:
|
||||
raise HTTPException(status_code=400, detail="invalid_profile_id")
|
||||
@ -135,7 +137,7 @@ class CrawlConfigOps:
|
||||
config: CrawlConfigIn,
|
||||
org: Organization,
|
||||
user: User,
|
||||
):
|
||||
) -> tuple[str, str, bool]:
|
||||
"""Add new crawl config"""
|
||||
data = config.dict()
|
||||
data["oid"] = org.id
|
||||
@ -183,7 +185,7 @@ class CrawlConfigOps:
|
||||
storage=org.storage,
|
||||
run_now=run_now,
|
||||
out_filename=out_filename,
|
||||
profile_filename=profile_filename,
|
||||
profile_filename=profile_filename or "",
|
||||
)
|
||||
|
||||
if crawl_id and run_now:
|
||||
@ -227,7 +229,7 @@ class CrawlConfigOps:
|
||||
|
||||
async def update_crawl_config(
|
||||
self, cid: uuid.UUID, org: Organization, user: User, update: UpdateCrawlConfig
|
||||
):
|
||||
) -> dict[str, bool]:
|
||||
# pylint: disable=too-many-locals
|
||||
"""Update name, scale, schedule, and/or tags for an existing crawl config"""
|
||||
|
||||
@ -336,7 +338,7 @@ class CrawlConfigOps:
|
||||
"metadata_changed": metadata_changed,
|
||||
}
|
||||
if run_now:
|
||||
crawl_id = await self.run_now(str(cid), org, user)
|
||||
crawl_id = await self.run_now(cid, org, user)
|
||||
ret["started"] = crawl_id
|
||||
return ret
|
||||
|
||||
@ -728,9 +730,9 @@ class CrawlConfigOps:
|
||||
"workflowIds": workflow_ids,
|
||||
}
|
||||
|
||||
async def run_now(self, cid: str, org: Organization, user: User):
|
||||
async def run_now(self, cid: uuid.UUID, org: Organization, user: User):
|
||||
"""run specified crawlconfig now"""
|
||||
crawlconfig = await self.get_crawl_config(uuid.UUID(cid), org.id)
|
||||
crawlconfig = await self.get_crawl_config(cid, org.id)
|
||||
|
||||
if not crawlconfig:
|
||||
raise HTTPException(
|
||||
@ -957,29 +959,29 @@ def init_crawl_config_api(
|
||||
|
||||
@router.get("/{cid}/seeds", response_model=PaginatedResponse)
|
||||
async def get_crawl_config_seeds(
|
||||
cid: str,
|
||||
cid: uuid.UUID,
|
||||
org: Organization = Depends(org_viewer_dep),
|
||||
pageSize: int = DEFAULT_PAGE_SIZE,
|
||||
page: int = 1,
|
||||
):
|
||||
seeds, total = await ops.get_seeds(uuid.UUID(cid), org.id, pageSize, page)
|
||||
seeds, total = await ops.get_seeds(cid, org.id, pageSize, page)
|
||||
return paginated_format(seeds, total, page, pageSize)
|
||||
|
||||
@router.get("/{cid}", response_model=CrawlConfigOut)
|
||||
async def get_crawl_config_out(
|
||||
cid: str, org: Organization = Depends(org_viewer_dep)
|
||||
cid: uuid.UUID, org: Organization = Depends(org_viewer_dep)
|
||||
):
|
||||
return await ops.get_crawl_config_out(uuid.UUID(cid), org)
|
||||
return await ops.get_crawl_config_out(cid, org)
|
||||
|
||||
@router.get(
|
||||
"/{cid}/revs",
|
||||
dependencies=[Depends(org_viewer_dep)],
|
||||
)
|
||||
async def get_crawl_config_revisions(
|
||||
cid: str, pageSize: int = DEFAULT_PAGE_SIZE, page: int = 1
|
||||
cid: uuid.UUID, pageSize: int = DEFAULT_PAGE_SIZE, page: int = 1
|
||||
):
|
||||
revisions, total = await ops.get_crawl_config_revs(
|
||||
uuid.UUID(cid), page_size=pageSize, page=page
|
||||
cid, page_size=pageSize, page=page
|
||||
)
|
||||
return paginated_format(revisions, total, page, pageSize)
|
||||
|
||||
@ -1000,24 +1002,24 @@ def init_crawl_config_api(
|
||||
@router.patch("/{cid}", dependencies=[Depends(org_crawl_dep)])
|
||||
async def update_crawl_config(
|
||||
update: UpdateCrawlConfig,
|
||||
cid: str,
|
||||
cid: uuid.UUID,
|
||||
org: Organization = Depends(org_crawl_dep),
|
||||
user: User = Depends(user_dep),
|
||||
):
|
||||
return await ops.update_crawl_config(uuid.UUID(cid), org, user, update)
|
||||
return await ops.update_crawl_config(cid, org, user, update)
|
||||
|
||||
@router.post("/{cid}/run")
|
||||
async def run_now(
|
||||
cid: str,
|
||||
cid: uuid.UUID,
|
||||
org: Organization = Depends(org_crawl_dep),
|
||||
user: User = Depends(user_dep),
|
||||
):
|
||||
) -> dict[str, str]:
|
||||
crawl_id = await ops.run_now(cid, org, user)
|
||||
return {"started": crawl_id}
|
||||
|
||||
@router.delete("/{cid}")
|
||||
async def make_inactive(cid: str, org: Organization = Depends(org_crawl_dep)):
|
||||
crawlconfig = await ops.get_crawl_config(uuid.UUID(cid), org.id)
|
||||
async def make_inactive(cid: uuid.UUID, org: Organization = Depends(org_crawl_dep)):
|
||||
crawlconfig = await ops.get_crawl_config(cid, org.id)
|
||||
|
||||
if not crawlconfig:
|
||||
raise HTTPException(
|
||||
|
@ -141,7 +141,7 @@ class CrawlManager(K8sAPI):
|
||||
has_scale_update
|
||||
or has_config_update
|
||||
or has_timeout_update
|
||||
or profile_filename
|
||||
or profile_filename is not None
|
||||
or has_max_crawl_size_update
|
||||
):
|
||||
await self._update_config_map(
|
||||
|
@ -7,7 +7,16 @@ from enum import Enum, IntEnum
|
||||
import os
|
||||
|
||||
from typing import Optional, List, Dict, Union, Literal, Any
|
||||
from pydantic import BaseModel, UUID4, conint, Field, HttpUrl, AnyHttpUrl, EmailStr
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
UUID4,
|
||||
conint,
|
||||
Field,
|
||||
HttpUrl,
|
||||
AnyHttpUrl,
|
||||
EmailStr,
|
||||
ConstrainedStr,
|
||||
)
|
||||
|
||||
# from fastapi_users import models as fastapi_users_models
|
||||
|
||||
@ -165,6 +174,14 @@ class ScopeType(str, Enum):
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
class EmptyStr(ConstrainedStr):
|
||||
"""empty string only"""
|
||||
|
||||
min_length = 0
|
||||
max_length = 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
class Seed(BaseModel):
|
||||
"""Crawl seed"""
|
||||
@ -231,7 +248,7 @@ class CrawlConfigIn(BaseModel):
|
||||
|
||||
jobType: Optional[JobType] = JobType.CUSTOM
|
||||
|
||||
profileid: Optional[str]
|
||||
profileid: Union[UUID4, EmptyStr, None]
|
||||
|
||||
autoAddCollections: Optional[List[UUID4]] = []
|
||||
tags: Optional[List[str]] = []
|
||||
@ -372,7 +389,7 @@ class UpdateCrawlConfig(BaseModel):
|
||||
|
||||
# crawl data: revision tracked
|
||||
schedule: Optional[str] = None
|
||||
profileid: Optional[str] = None
|
||||
profileid: Union[UUID4, EmptyStr, None] = None
|
||||
crawlTimeout: Optional[int] = None
|
||||
maxCrawlSize: Optional[int] = None
|
||||
scale: Optional[conint(ge=1, le=MAX_CRAWL_SCALE)] = None # type: ignore
|
||||
|
@ -431,12 +431,10 @@ def init_orgs_api(app, mdb, user_manager, invites, user_dep):
|
||||
|
||||
ops = OrgOps(mdb, invites)
|
||||
|
||||
async def org_dep(oid: str, user: User = Depends(user_dep)):
|
||||
org = await ops.get_org_for_user_by_id(uuid.UUID(oid), user)
|
||||
async def org_dep(oid: uuid.UUID, user: User = Depends(user_dep)):
|
||||
org = await ops.get_org_for_user_by_id(oid, user)
|
||||
if not org:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Organization '{oid}' not found"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="org_not_found")
|
||||
if not org.is_viewer(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
@ -466,8 +464,8 @@ def init_orgs_api(app, mdb, user_manager, invites, user_dep):
|
||||
|
||||
return org
|
||||
|
||||
async def org_public(oid: str):
|
||||
org = await ops.get_org_by_id(uuid.UUID(oid))
|
||||
async def org_public(oid: uuid.UUID):
|
||||
org = await ops.get_org_by_id(oid)
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="org_not_found")
|
||||
|
||||
@ -651,7 +649,7 @@ def init_orgs_api(app, mdb, user_manager, invites, user_dep):
|
||||
@router.post("/remove", tags=["invites"])
|
||||
async def remove_user_from_org(
|
||||
remove: RemoveFromOrg, org: Organization = Depends(org_owner_dep)
|
||||
):
|
||||
) -> dict[str, bool]:
|
||||
other_user = await user_manager.get_by_email(remove.email)
|
||||
|
||||
if org.is_owner(other_user):
|
||||
|
@ -161,7 +161,8 @@ class ProfileOps:
|
||||
|
||||
await self.crawl_manager.delete_profile_browser(browser_commit.browserid)
|
||||
|
||||
file_size = resource["bytes"]
|
||||
# backwards compatibility
|
||||
file_size = resource.get("size") or resource.get("bytes")
|
||||
|
||||
profile_file = ProfileFile(
|
||||
hash=resource["hash"],
|
||||
|
@ -8,7 +8,7 @@ from urllib.parse import unquote
|
||||
|
||||
import asyncio
|
||||
from io import BufferedReader
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Any
|
||||
from fastapi import Depends, UploadFile, File
|
||||
from fastapi import HTTPException
|
||||
from pydantic import UUID4
|
||||
@ -63,12 +63,12 @@ class UploadOps(BaseCrawlOps):
|
||||
filename: str,
|
||||
name: Optional[str],
|
||||
description: Optional[str],
|
||||
collections: Optional[List[UUID4]],
|
||||
collections: Optional[List[str]],
|
||||
tags: Optional[List[str]],
|
||||
org: Organization,
|
||||
user: User,
|
||||
replaceId: Optional[str],
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
"""Upload streaming file, length unknown"""
|
||||
if await self.orgs.storage_quota_reached(org.id):
|
||||
raise HTTPException(status_code=403, detail="storage_quota_reached")
|
||||
@ -122,11 +122,11 @@ class UploadOps(BaseCrawlOps):
|
||||
uploads: List[UploadFile],
|
||||
name: Optional[str],
|
||||
description: Optional[str],
|
||||
collections: Optional[List[UUID4]],
|
||||
collections: Optional[List[str]],
|
||||
tags: Optional[List[str]],
|
||||
org: Organization,
|
||||
user: User,
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
"""handle uploading content to uploads subdir + request subdir"""
|
||||
if await self.orgs.storage_quota_reached(org.id):
|
||||
raise HTTPException(status_code=403, detail="storage_quota_reached")
|
||||
@ -145,22 +145,31 @@ class UploadOps(BaseCrawlOps):
|
||||
files.append(file_reader.file_prep.get_crawl_file())
|
||||
|
||||
return await self._create_upload(
|
||||
files, name, description, collections, tags, id_, org, user
|
||||
files, name, description, collections, tags, str(id_), org, user
|
||||
)
|
||||
|
||||
async def _create_upload(
|
||||
self, files, name, description, collections, tags, id_, org, user
|
||||
):
|
||||
self,
|
||||
files: List[UploadFile],
|
||||
name: Optional[str],
|
||||
description: Optional[str],
|
||||
collections: Optional[List[str]],
|
||||
tags: Optional[List[str]],
|
||||
crawl_id: str,
|
||||
org: Organization,
|
||||
user: User,
|
||||
) -> dict[str, Any]:
|
||||
now = dt_now()
|
||||
# ts_now = now.strftime("%Y%m%d%H%M%S")
|
||||
# crawl_id = f"upload-{ts_now}-{str(id_)[:12]}"
|
||||
crawl_id = str(id_)
|
||||
file_size = sum(file_.size or 0 for file_ in files)
|
||||
|
||||
file_size = sum(file_.size for file_ in files)
|
||||
|
||||
collection_uuids = []
|
||||
for coll in collections:
|
||||
collection_uuids.append(uuid.UUID(coll))
|
||||
collection_uuids: List[uuid.UUID] = []
|
||||
if collections:
|
||||
try:
|
||||
for coll in collections:
|
||||
collection_uuids.append(uuid.UUID(coll))
|
||||
# pylint: disable=raise-missing-from
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="invalid_collection_id")
|
||||
|
||||
uploaded = UploadedCrawl(
|
||||
id=crawl_id,
|
||||
@ -299,7 +308,7 @@ def init_uploads_api(
|
||||
tags: Optional[str] = "",
|
||||
org: Organization = Depends(org_crawl_dep),
|
||||
user: User = Depends(user_dep),
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
name = unquote(name)
|
||||
description = unquote(description)
|
||||
colls_list = []
|
||||
@ -325,7 +334,7 @@ def init_uploads_api(
|
||||
replaceId: Optional[str] = "",
|
||||
org: Organization = Depends(org_crawl_dep),
|
||||
user: User = Depends(user_dep),
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
name = unquote(name)
|
||||
description = unquote(description)
|
||||
colls_list = []
|
||||
|
@ -6,9 +6,9 @@ import os
|
||||
import uuid
|
||||
import asyncio
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import UUID4, EmailStr
|
||||
from pydantic import EmailStr
|
||||
|
||||
from fastapi import (
|
||||
Request,
|
||||
@ -178,11 +178,11 @@ class UserManager:
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_names_by_ids(self, user_ids):
|
||||
async def get_user_names_by_ids(self, user_ids: List[str]) -> dict[str, str]:
|
||||
"""return list of user names for given ids"""
|
||||
user_ids = [UUID4(id_) for id_ in user_ids]
|
||||
user_uuid_ids = [uuid.UUID(id_) for id_ in user_ids]
|
||||
cursor = self.users.find(
|
||||
{"id": {"$in": user_ids}}, projection=["id", "name", "email"]
|
||||
{"id": {"$in": user_uuid_ids}}, projection=["id", "name", "email"]
|
||||
)
|
||||
return await cursor.to_list(length=1000)
|
||||
|
||||
@ -369,7 +369,7 @@ class UserManager:
|
||||
|
||||
return user
|
||||
|
||||
async def get_by_id(self, _id: UUID4) -> Optional[User]:
|
||||
async def get_by_id(self, _id: uuid.UUID) -> Optional[User]:
|
||||
"""get user by unique id"""
|
||||
user = await self.users.find_one({"id": _id})
|
||||
|
||||
@ -409,7 +409,7 @@ class UserManager:
|
||||
raise exc
|
||||
|
||||
try:
|
||||
user_uuid = UUID4(user_id)
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
raise exc
|
||||
|
||||
@ -456,7 +456,7 @@ class UserManager:
|
||||
user_id = data["user_id"]
|
||||
|
||||
try:
|
||||
user_uuid = UUID4(user_id)
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -699,8 +699,8 @@ def init_users_router(current_active_user, user_manager) -> APIRouter:
|
||||
return await user_manager.format_invite(invite)
|
||||
|
||||
@users_router.get("/invite/{token}", tags=["invites"])
|
||||
async def get_invite_info(token: str, email: str):
|
||||
invite = await user_manager.invites.get_valid_invite(uuid.UUID(token), email)
|
||||
async def get_invite_info(token: uuid.UUID, email: str):
|
||||
invite = await user_manager.invites.get_valid_invite(token, email)
|
||||
return await user_manager.format_invite(invite)
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
Loading…
Reference in New Issue
Block a user