browsertrix/backend/btrixcloud/basecrawls.py
Ilya Kreymer 7d694754c6
uploads api ext: (#970)
- also support collectionId filter on /all-crawls
- update tests
2023-07-09 22:12:54 -07:00

542 lines
16 KiB
Python

""" base crawl type """
import asyncio
import uuid
import os
from datetime import datetime, timedelta
from typing import Optional, Dict, List
from pydantic import BaseModel, UUID4
from fastapi import HTTPException, Depends
from .db import BaseMongoModel
from .orgs import Organization
from .pagination import PaginatedResponseModel, paginated_format, DEFAULT_PAGE_SIZE
from .storages import get_presigned_url, delete_crawl_file_object
from .users import User
from .utils import dt_now
# ============================================================================
class CrawlFile(BaseModel):
    """file from a crawl"""

    # Required descriptors of the stored file.
    filename: str
    hash: str
    size: int

    # Optional storage name; not read anywhere in this module.
    def_storage_name: Optional[str] = None

    # Presigned download URL and the moment it stops being valid; both stay
    # unset until a URL has been generated and written back to the db.
    presignedUrl: Optional[str] = None
    expireAt: Optional[datetime] = None
# ============================================================================
class CrawlFileOut(BaseModel):
    """output for file from a crawl (conformance to Data Resource Spec)"""

    # File name as stored on the crawl.
    name: str

    # Location the file can be fetched from (a presigned URL in this module).
    path: str

    hash: str
    size: int

    # Id of the crawl the file belongs to, when the producer knows it.
    crawlId: Optional[str] = None
# ============================================================================
class BaseCrawl(BaseMongoModel):
    """Base Crawl object (representing crawls, uploads and manual sessions)"""

    # Identity and ownership.
    id: str
    userid: UUID4
    oid: UUID4

    # Lifecycle.
    started: datetime
    finished: Optional[datetime] = None
    state: str

    stats: Optional[Dict[str, int]] = None

    # Stored files plus their aggregate size / count.
    files: Optional[List[CrawlFile]] = []
    fileSize: int = 0
    fileCount: int = 0

    notes: Optional[str] = None
    errors: Optional[List[str]] = []

    # Ids of the collections this crawl has been added to.
    collections: Optional[List[UUID4]] = []
# ============================================================================
class BaseCrawlOut(BaseMongoModel):
    """Base crawl output model"""

    # pylint: disable=duplicate-code

    # Value of the document's "type" field, when present.
    type: Optional[str] = None

    # Identity and ownership; userName is filled in from the users lookup.
    id: str
    userid: UUID4
    oid: UUID4
    userName: Optional[str] = None

    name: Optional[str] = None
    description: Optional[str] = None

    # Lifecycle.
    started: datetime
    finished: Optional[datetime] = None
    state: str

    stats: Optional[Dict[str, int]] = None

    fileSize: int = 0
    fileCount: int = 0

    tags: Optional[List[str]] = []
    notes: Optional[str] = None
    errors: Optional[List[str]] = None

    # Ids of the collections this crawl has been added to.
    collections: Optional[List[UUID4]] = []
# ============================================================================
class BaseCrawlOutWithResources(BaseCrawlOut):
    """includes resources"""

    # One entry per crawl file, each carrying a resolved download path.
    resources: Optional[List[CrawlFileOut]] = []
# ============================================================================
class UpdateCrawl(BaseModel):
    """Update crawl"""

    # Only fields explicitly present in the request body are applied, because
    # the consumer serializes this model with exclude_unset=True.
    tags: Optional[List[str]] = []
    notes: Optional[str] = None
# ============================================================================
class DeleteCrawlList(BaseModel):
    """delete crawl list POST body"""

    # Ids of the crawls to delete.
    crawl_ids: List[str]
# ============================================================================
class BaseCrawlOps:
    """operations that apply to all crawls"""

    # pylint: disable=duplicate-code, too-many-arguments, too-many-locals

    def __init__(self, mdb, users, crawl_manager):
        """Bind the "crawls" collection and the collaborating managers.

        The presigned-URL lifetime is read from PRESIGN_DURATION_MINUTES
        (default 60) and stored in seconds.
        """
        self.crawls = mdb["crawls"]
        self.crawl_manager = crawl_manager
        self.user_manager = users

        # `or 60` also covers the variable being set but empty, which would
        # otherwise make int() raise ValueError at startup.
        self.presign_duration_seconds = (
            int(os.environ.get("PRESIGN_DURATION_MINUTES") or 60) * 60
        )

        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be collected mid-run.
        self._background_tasks = set()

    async def get_crawl_raw(
        self,
        crawlid: str,
        org: Optional[Organization] = None,
        type_: Optional[str] = None,
    ):
        """Return the raw db document for a single crawl.

        The lookup is by ``_id`` and is narrowed to ``org`` and ``type_`` when
        those are given. Raises HTTPException(404) if nothing matches.
        """
        query = {"_id": crawlid}
        if org:
            query["oid"] = org.id

        if type_:
            query["type"] = type_

        res = await self.crawls.find_one(query)

        if not res:
            raise HTTPException(status_code=404, detail=f"Crawl not found: {crawlid}")

        return res

    async def get_crawl(
        self,
        crawlid: str,
        org: Optional[Organization] = None,
        type_: Optional[str] = None,
    ):
        """Return a single crawl as BaseCrawlOutWithResources.

        Files are converted to resources with presigned URLs, "errors" is
        dropped from the output, and userName is filled in when the owning
        user can be found. Raises HTTPException(404) if the crawl is missing.
        """
        res = await self.get_crawl_raw(crawlid, org, type_)

        # pop() instead of del: a document without these keys must not turn
        # into a KeyError (and thus a 500) here.
        files_data = res.pop("files", None)
        if files_data:
            files = [CrawlFile(**data) for data in files_data]
            # NOTE(review): org may be None here (superuser route) and is
            # forwarded unchanged to the presign helper - confirm the helper
            # tolerates that.
            res["resources"] = await self._resolve_signed_urls(files, org, crawlid)

        res.pop("errors", None)

        crawl = BaseCrawlOutWithResources.from_dict(res)

        user = await self.user_manager.get(crawl.userid)
        if user:
            # pylint: disable=invalid-name
            crawl.userName = user.name

        return crawl

    async def get_resource_resolved_raw_crawl(
        self, crawlid: str, org: Organization, type_=None
    ):
        """Return the raw crawl document with a "resources" list added.

        Raises HTTPException(404) if the crawl is missing.
        """
        res = await self.get_crawl_raw(crawlid=crawlid, type_=type_, org=org)
        # A crawl may have no "files" key (or a null value); treat as empty.
        files = [CrawlFile(**data) for data in res.get("files") or []]
        res["resources"] = await self._resolve_signed_urls(files, org, res["_id"])
        return res

    async def update_crawl(
        self, crawl_id: str, org: Organization, update: UpdateCrawl, type_=None
    ):
        """Update existing crawl (tags and notes only for now)

        Raises HTTPException(400) when the request sets no fields and
        HTTPException(404) when no crawl matches id / org / type.
        """
        update_values = update.dict(exclude_unset=True)
        if not update_values:
            raise HTTPException(status_code=400, detail="no_update_data")

        query = {"_id": crawl_id, "oid": org.id}
        if type_:
            query["type"] = type_

        # update in db
        result = await self.crawls.find_one_and_update(
            query,
            {"$set": update_values},
        )

        if not result:
            raise HTTPException(status_code=404, detail="crawl_not_found")

        return {"updated": True}

    async def delete_crawls(
        self, org: Organization, delete_list: DeleteCrawlList, type_=None
    ):
        """Delete a list of crawls by id for given org

        Returns ``(deleted_count, size, cids_to_update)``: the number of
        documents removed, the total size of the files removed from storage,
        and the set of "cid" values found on the processed crawls.
        Raises HTTPException(404) if any listed id does not exist in the org.
        """
        cids_to_update = set()

        size = 0

        for crawl_id in delete_list.crawl_ids:
            crawl = await self.get_crawl_raw(crawl_id, org)

            # The final delete_many() is restricted to type_, so a crawl of a
            # different type keeps its document; its files must be kept too.
            if type_ and crawl.get("type") != type_:
                continue

            size += await self._delete_crawl_files(crawl, org)

            cid = crawl.get("cid")
            if cid:
                cids_to_update.add(cid)

        query = {"_id": {"$in": delete_list.crawl_ids}, "oid": org.id}
        if type_:
            query["type"] = type_

        res = await self.crawls.delete_many(query)

        return res.deleted_count, size, cids_to_update

    async def _delete_crawl_files(self, crawl, org: Organization):
        """Delete files associated with crawl from storage.

        Returns the summed size of the files. Raises HTTPException(400) as
        soon as one deletion does not report status 204.
        """
        crawl = BaseCrawl.from_dict(crawl)
        size = 0
        # files is Optional on the model, so guard against None.
        for file_ in crawl.files or []:
            size += file_.size
            status_code = await delete_crawl_file_object(org, file_, self.crawl_manager)
            if status_code != 204:
                raise HTTPException(status_code=400, detail="file_deletion_error")

        return size

    async def _resolve_signed_urls(
        self, files: List[CrawlFile], org: Organization, crawl_id: Optional[str] = None
    ):
        """Convert crawl files to CrawlFileOut entries with usable URLs.

        A file's stored presigned URL is reused while it is still valid;
        otherwise a new one is requested and the db write-back is scheduled
        as a background task. Returns an empty list when there are no files.
        """
        if not files:
            print("no files")
            # Always hand back a list so callers never see None resources.
            return []

        delta = timedelta(seconds=self.presign_duration_seconds)

        updates = []
        out_files = []

        for file_ in files:
            presigned_url = file_.presignedUrl
            now = dt_now()

            # A url stored without an expiry cannot be trusted, and comparing
            # against None would raise TypeError, so treat it as expired.
            if not presigned_url or not file_.expireAt or now >= file_.expireAt:
                exp = now + delta
                presigned_url = await get_presigned_url(
                    org, file_, self.crawl_manager, self.presign_duration_seconds
                )

                match = {"files.filename": file_.filename}
                if crawl_id is not None:
                    # Scope the write-back to the owning crawl document so an
                    # identical filename on another crawl is never touched.
                    # All callers in this module pass the document's _id.
                    match["_id"] = crawl_id

                updates.append(
                    (
                        match,
                        {
                            "$set": {
                                "files.$.presignedUrl": presigned_url,
                                "files.$.expireAt": exp,
                            }
                        },
                    )
                )

            out_files.append(
                CrawlFileOut(
                    name=file_.filename,
                    path=presigned_url,
                    hash=file_.hash,
                    size=file_.size,
                    crawlId=crawl_id,
                )
            )

        if updates:
            # Keep a reference until the task finishes (see __init__).
            task = asyncio.create_task(self._update_presigned(updates))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

        # print("presigned", out_files)

        return out_files

    async def _update_presigned(self, updates):
        """Apply each queued (filter, update) pair to the crawls collection."""
        for update in updates:
            await self.crawls.find_one_and_update(*update)

    async def add_to_collection(
        self, crawl_ids: List[str], collection_id: uuid.UUID, org: Organization
    ):
        """Add crawls to collection.

        Raises HTTPException(404) if a crawl is not found in the org and
        HTTPException(400) if a crawl already belongs to the collection.
        """
        for crawl_id in crawl_ids:
            crawl_raw = await self.get_crawl_raw(crawl_id, org)
            crawl_collections = crawl_raw.get("collections")
            # Membership is about the collection id, not the crawl id.
            if crawl_collections and collection_id in crawl_collections:
                raise HTTPException(
                    status_code=400, detail="crawl_already_in_collection"
                )

            await self.crawls.find_one_and_update(
                {"_id": crawl_id},
                {"$push": {"collections": collection_id}},
            )

    async def remove_from_collection(
        self, crawl_ids: List[str], collection_id: uuid.UUID
    ):
        """Remove crawls from collection."""
        for crawl_id in crawl_ids:
            await self.crawls.find_one_and_update(
                {"_id": crawl_id},
                {"$pull": {"collections": collection_id}},
            )

    async def remove_collection_from_all_crawls(self, collection_id: uuid.UUID):
        """Remove collection id from all crawls it's currently in."""
        await self.crawls.update_many(
            {"collections": collection_id},
            {"$pull": {"collections": collection_id}},
        )

    # pylint: disable=too-many-branches
    async def list_all_base_crawls(
        self,
        org: Optional[Organization] = None,
        userid: Optional[uuid.UUID] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        collection_id: Optional[uuid.UUID] = None,
        states: Optional[List[str]] = None,
        page_size: int = DEFAULT_PAGE_SIZE,
        page: int = 1,
        sort_by: Optional[str] = None,
        sort_direction: int = -1,
        cls_type: type[BaseCrawlOut] = BaseCrawlOut,
        type_=None,
    ):
        """List crawls of all types from the db

        Every filter is optional. ``page`` is 1-based. Returns
        ``(crawls, total)`` where crawls are ``cls_type`` instances for the
        requested page and total counts all matches. Raises
        HTTPException(400) for an unsupported sort_by or sort_direction.
        """
        # Zero-index page for query; clamp so page <= 0 cannot yield a
        # negative $skip, which the database rejects.
        skip = max(page - 1, 0) * page_size

        oid = org.id if org else None

        query = {}
        if type_:
            query["type"] = type_
        if oid:
            query["oid"] = oid

        if userid:
            query["userid"] = userid

        if states:
            # validated_states = [value for value in state if value in ALL_CRAWL_STATES]
            query["state"] = {"$in": states}

        aggregate = [{"$match": query}]

        if name:
            aggregate.append({"$match": {"name": name}})

        if description:
            aggregate.append({"$match": {"description": description}})

        if collection_id:
            aggregate.append({"$match": {"collections": {"$in": [collection_id]}}})

        if sort_by:
            if sort_by not in ("started", "finished"):
                raise HTTPException(status_code=400, detail="invalid_sort_by")
            if sort_direction not in (1, -1):
                raise HTTPException(status_code=400, detail="invalid_sort_direction")

            aggregate.append({"$sort": {sort_by: sort_direction}})

        aggregate.extend(
            [
                {
                    "$lookup": {
                        "from": "users",
                        "localField": "userid",
                        "foreignField": "id",
                        "as": "userName",
                    },
                },
                # $lookup yields an array; keep only the first user's name.
                {"$set": {"userName": {"$arrayElemAt": ["$userName.name", 0]}}},
                {
                    "$facet": {
                        "items": [
                            {"$skip": skip},
                            {"$limit": page_size},
                        ],
                        "total": [{"$count": "count"}],
                    }
                },
            ]
        )

        # Get total
        cursor = self.crawls.aggregate(aggregate)
        results = await cursor.to_list(length=1)
        result = results[0]
        items = result["items"]

        try:
            total = int(result["total"][0]["count"])
        except (IndexError, ValueError):
            # "total" is an empty list when nothing matched.
            total = 0

        crawls = []
        for res in items:
            files = None
            if res.get("files"):
                files = [CrawlFile(**data) for data in res["files"]]
                del res["files"]

            crawl = cls_type.from_dict(res)

            # Only output models that declare "resources" get signed URLs.
            if hasattr(crawl, "resources"):
                # pylint: disable=attribute-defined-outside-init
                crawl.resources = await self._resolve_signed_urls(files, org, crawl.id)

            crawls.append(crawl)

        return crawls, total

    async def delete_crawls_all_types(
        self, delete_list: DeleteCrawlList, org: Optional[Organization] = None
    ):
        """Delete the listed crawls regardless of their type.

        Raises HTTPException(404) when no document was deleted.
        """
        # NOTE(review): org is declared Optional, but delete_crawls() reads
        # org.id unconditionally - callers must pass a real organization.
        deleted_count, _, _ = await self.delete_crawls(org, delete_list)

        if deleted_count < 1:
            raise HTTPException(status_code=404, detail="crawl_not_found")

        return {"deleted": True}
# ============================================================================
def init_base_crawls_api(app, mdb, users, crawl_manager, orgs, user_dep):
    """Register the /all-crawls routes on ``app``, backed by one BaseCrawlOps."""
    # pylint: disable=invalid-name, duplicate-code, too-many-arguments

    ops = BaseCrawlOps(mdb, users, crawl_manager)

    # Handler names and parameter names below are part of the HTTP interface
    # (operation ids, query parameter names) and must stay as they are.

    @app.get(
        "/orgs/{oid}/all-crawls",
        tags=["all-crawls"],
        response_model=PaginatedResponseModel,
    )
    async def list_all_base_crawls(
        org: Organization = Depends(orgs.org_viewer_dep),
        pageSize: int = DEFAULT_PAGE_SIZE,
        page: int = 1,
        userid: Optional[UUID4] = None,
        name: Optional[str] = None,
        state: Optional[str] = None,
        description: Optional[str] = None,
        collectionId: Optional[UUID4] = None,
        sortBy: Optional[str] = "finished",
        sortDirection: Optional[int] = -1,
    ):
        # "state" arrives as a comma-separated list of state names.
        states = None
        if state:
            states = state.split(",")

        listing = await ops.list_all_base_crawls(
            org=org,
            userid=userid,
            name=name,
            description=description,
            collection_id=collectionId,
            states=states,
            page_size=pageSize,
            page=page,
            sort_by=sortBy,
            sort_direction=sortDirection,
        )
        crawls, total = listing
        return paginated_format(crawls, total, page, pageSize)

    @app.get(
        "/orgs/{oid}/all-crawls/{crawlid}",
        tags=["all-crawls"],
        response_model=BaseCrawlOutWithResources,
    )
    async def get_base_crawl(
        crawlid: str, org: Organization = Depends(orgs.org_crawl_dep)
    ):
        raw_crawl = await ops.get_resource_resolved_raw_crawl(crawlid, org)
        return BaseCrawlOutWithResources.from_dict(raw_crawl)

    # Registered ahead of the per-org replay.json route so the literal "all"
    # segment is matched before the {oid} placeholder.
    @app.get(
        "/orgs/all/all-crawls/{crawl_id}/replay.json",
        tags=["all-crawls"],
        response_model=BaseCrawlOutWithResources,
    )
    async def get_base_crawl_admin(crawl_id, user: User = Depends(user_dep)):
        # Cross-org lookup: superusers only.
        if user.is_superuser:
            return await ops.get_crawl(crawl_id, None)

        raise HTTPException(status_code=403, detail="Not Allowed")

    @app.get(
        "/orgs/{oid}/all-crawls/{crawl_id}/replay.json",
        tags=["all-crawls"],
        response_model=BaseCrawlOutWithResources,
    )
    async def get_crawl(crawl_id, org: Organization = Depends(orgs.org_viewer_dep)):
        crawl = await ops.get_crawl(crawl_id, org)
        return crawl

    @app.post("/orgs/{oid}/all-crawls/delete", tags=["all-crawls"])
    async def delete_crawls_all_types(
        delete_list: DeleteCrawlList,
        org: Organization = Depends(orgs.org_crawl_dep),
    ):
        outcome = await ops.delete_crawls_all_types(delete_list, org)
        return outcome