browsertrix/backend/btrixcloud/colls.py
Ilya Kreymer 6384d8b5f1
Additional Type Hints / Type Fix Pass (#1320)
This PR adds more type safety to the backend codebase:
- All calls to ops classes should be type checked
- Avoiding circular references with TYPE_CHECKING conditional
- Consistent UUID usage: replace uuid.UUID / UUID4 with just UUID
- Crawl states moved to models, made into lists
- Additional typing added as needed, fixed a few type related errors
- CrawlOps / UploadOps / BaseCrawlOps now all have the same param init order to simplify changes
2023-10-30 12:59:24 -04:00


"""
Collections API
"""
from collections import Counter
from datetime import datetime
from uuid import UUID, uuid4
from typing import Optional, List, TYPE_CHECKING, cast
import asyncio
import pymongo
from fastapi import Depends, HTTPException, Response
from fastapi.responses import StreamingResponse
from .pagination import DEFAULT_PAGE_SIZE, paginated_format
from .models import (
Collection,
CollIn,
CollOut,
CollIdName,
UpdateColl,
AddRemoveCrawlList,
CrawlOutWithResources,
Organization,
PaginatedResponse,
SUCCESSFUL_STATES,
)
if TYPE_CHECKING:
from .orgs import OrgOps
from .storages import StorageOps
from .webhooks import EventWebhookOps
from .crawls import CrawlOps
else:
OrgOps = StorageOps = EventWebhookOps = CrawlOps = object
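# NOTE: the ops classes above are imported only for type checking; at runtime
# they are aliased to `object` because concrete instances are injected into
# CollectionOps, which keeps the ops modules free of circular imports.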
# ============================================================================
class CollectionOps:
"""ops for working with named collections of crawls"""
# pylint: disable=too-many-arguments
orgs: OrgOps
storage_ops: StorageOps
event_webhook_ops: EventWebhookOps
crawl_ops: CrawlOps
def __init__(self, mdb, storage_ops, orgs, event_webhook_ops):
self.collections = mdb["collections"]
self.crawls = mdb["crawls"]
self.crawl_configs = mdb["crawl_configs"]
self.crawl_ops = cast(CrawlOps, None)
self.orgs = orgs
self.storage_ops = storage_ops
self.event_webhook_ops = event_webhook_ops
def set_crawl_ops(self, ops):
"""set crawl ops"""
self.crawl_ops = ops
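# crawl_ops is attached after construction via set_crawl_ops; the deferred
# wiring avoids a circular init-order dependency between CrawlOps and
# CollectionOps.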
async def init_index(self):
"""init lookup index"""
await self.collections.create_index(
[("oid", pymongo.ASCENDING), ("name", pymongo.ASCENDING)], unique=True
)
await self.collections.create_index(
[("oid", pymongo.ASCENDING), ("description", pymongo.ASCENDING)]
)
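# The unique (oid, name) index is what makes duplicate collection names within
# an org fail with DuplicateKeyError, surfaced below as "collection_name_taken".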
async def add_collection(self, oid: UUID, coll_in: CollIn):
"""Add new collection"""
crawl_ids = coll_in.crawlIds if coll_in.crawlIds else []
coll_id = uuid4()
modified = datetime.utcnow().replace(microsecond=0, tzinfo=None)
coll = Collection(
id=coll_id,
oid=oid,
name=coll_in.name,
description=coll_in.description,
modified=modified,
isPublic=coll_in.isPublic,
)
try:
await self.collections.insert_one(coll.to_dict())
org = await self.orgs.get_org_by_id(oid)
if crawl_ids:
await self.crawl_ops.add_to_collection(crawl_ids, coll_id, org)
await self.update_collection_counts_and_tags(coll_id)
asyncio.create_task(
self.event_webhook_ops.create_added_to_collection_notification(
crawl_ids, coll_id, org
)
)
return {"added": True, "id": coll_id, "name": coll.name}
except pymongo.errors.DuplicateKeyError:
# pylint: disable=raise-missing-from
raise HTTPException(status_code=400, detail="collection_name_taken")
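# The webhook notification above is dispatched with asyncio.create_task, so the
# API response does not wait on webhook delivery and a failure there does not
# roll back the insert.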
async def update_collection(
self, coll_id: UUID, org: Organization, update: UpdateColl
):
"""Update collection"""
query = update.dict(exclude_unset=True)
if len(query) == 0:
raise HTTPException(status_code=400, detail="no_update_data")
query["modified"] = datetime.utcnow().replace(microsecond=0, tzinfo=None)
try:
result = await self.collections.find_one_and_update(
{"_id": coll_id, "oid": org.id},
{"$set": query},
return_document=pymongo.ReturnDocument.AFTER,
)
except pymongo.errors.DuplicateKeyError:
# pylint: disable=raise-missing-from
raise HTTPException(status_code=400, detail="collection_name_taken")
if not result:
raise HTTPException(status_code=404, detail="collection_not_found")
return {"updated": True}
async def add_crawls_to_collection(
self, coll_id: UUID, crawl_ids: List[str], org: Organization
) -> CollOut:
"""Add crawls to collection"""
await self.crawl_ops.add_to_collection(crawl_ids, coll_id, org)
modified = datetime.utcnow().replace(microsecond=0, tzinfo=None)
result = await self.collections.find_one_and_update(
{"_id": coll_id},
{"$set": {"modified": modified}},
return_document=pymongo.ReturnDocument.AFTER,
)
if not result:
raise HTTPException(status_code=404, detail="collection_not_found")
await self.update_collection_counts_and_tags(coll_id)
asyncio.create_task(
self.event_webhook_ops.create_added_to_collection_notification(
crawl_ids, coll_id, org
)
)
return await self.get_collection(coll_id, org)
async def remove_crawls_from_collection(
self, coll_id: UUID, crawl_ids: List[str], org: Organization
) -> CollOut:
"""Remove crawls from collection"""
await self.crawl_ops.remove_from_collection(crawl_ids, coll_id)
modified = datetime.utcnow().replace(microsecond=0, tzinfo=None)
result = await self.collections.find_one_and_update(
{"_id": coll_id},
{"$set": {"modified": modified}},
return_document=pymongo.ReturnDocument.AFTER,
)
if not result:
raise HTTPException(status_code=404, detail="collection_not_found")
await self.update_collection_counts_and_tags(coll_id)
asyncio.create_task(
self.event_webhook_ops.create_removed_from_collection_notification(
crawl_ids, coll_id, org
)
)
return await self.get_collection(coll_id, org)
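# Both add and remove follow the same pattern: update crawl membership, bump
# the collection's modified timestamp, recompute counts/tags, and fire the
# webhook notification in the background before returning the fresh CollOut.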
async def get_collection(
self, coll_id: UUID, org: Organization, resources=False, public_only=False
) -> CollOut:
"""Get collection by id"""
query: dict[str, object] = {"_id": coll_id}
if public_only:
query["isPublic"] = True
result = await self.collections.find_one(query)
if not result:
raise HTTPException(status_code=404, detail="collection_not_found")
if resources:
result["resources"] = await self.get_collection_crawl_resources(
coll_id, org
)
return CollOut.from_dict(result)
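# resources=True attaches pre-signed WACZ file entries (see
# get_collection_crawl_resources); public_only=True restricts lookup to
# collections marked isPublic, used by the public replay endpoint below.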
async def list_collections(
self,
oid: UUID,
page_size: int = DEFAULT_PAGE_SIZE,
page: int = 1,
sort_by: Optional[str] = None,
sort_direction: int = 1,
name: Optional[str] = None,
name_prefix: Optional[str] = None,
):
"""List all collections for org"""
# pylint: disable=too-many-locals, duplicate-code
# Zero-index page for query
page = page - 1
skip = page * page_size
match_query: dict[str, object] = {"oid": oid}
if name:
match_query["name"] = name
elif name_prefix:
regex_pattern = f"^{name_prefix}"
match_query["name"] = {"$regex": regex_pattern, "$options": "i"}
aggregate = [{"$match": match_query}]
if sort_by:
if sort_by not in ("modified", "name", "description", "totalSize"):
raise HTTPException(status_code=400, detail="invalid_sort_by")
if sort_direction not in (1, -1):
raise HTTPException(status_code=400, detail="invalid_sort_direction")
aggregate.extend([{"$sort": {sort_by: sort_direction}}])
aggregate.extend(
[
{
"$facet": {
"items": [
{"$skip": skip},
{"$limit": page_size},
],
"total": [{"$count": "count"}],
}
},
]
)
cursor = self.collections.aggregate(
aggregate, collation=pymongo.collation.Collation(locale="en")
)
results = await cursor.to_list(length=1)
result = results[0]
items = result["items"]
try:
total = int(result["total"][0]["count"])
except (IndexError, ValueError):
total = 0
collections = [CollOut.from_dict(res) for res in items]
return collections, total
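# The $facet stage returns the requested page and the total count from a single
# aggregation, e.g. (illustrative shape only):
#   [{"$match": {"oid": ...}}, {"$sort": {"name": 1}},
#    {"$facet": {"items": [{"$skip": 0}, {"$limit": 25}],
#                "total": [{"$count": "count"}]}}]
# The English collation gives locale-aware (rather than byte-order) sorting.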
async def get_collection_crawl_resources(self, coll_id: UUID, org: Organization):
"""Return pre-signed resources for all collection crawl files."""
coll = await self.get_collection(coll_id, org)
if not coll:
raise HTTPException(status_code=404, detail="collection_not_found")
all_files = []
crawls, _ = await self.crawl_ops.list_all_base_crawls(
collection_id=coll_id,
states=list(SUCCESSFUL_STATES),
page_size=10_000,
cls_type=CrawlOutWithResources,
)
for crawl in crawls:
if crawl.resources:
all_files.extend(crawl.resources)
return all_files
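# Resources come from every crawl in the collection that finished in one of
# SUCCESSFUL_STATES; page_size=10_000 acts as an effectively unbounded page so
# a single call covers the whole collection.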
async def get_collection_names(self, uuids: List[UUID]):
"""return object of {_id, names} given list of collection ids"""
cursor = self.collections.find(
{"_id": {"$in": uuids}}, projection=["_id", "name"]
)
names = await cursor.to_list(length=1000)
names = [
CollIdName(id=namedata["_id"], name=namedata["name"]) for namedata in names
]
return names
async def get_collection_search_values(self, org: Organization):
"""Return list of collection names"""
names = await self.collections.distinct("name", {"oid": org.id})
# Remove empty strings
names = [name for name in names if name]
return {"names": names}
async def delete_collection(self, coll_id: UUID, org: Organization):
"""Delete collection and remove from associated crawls."""
await self.crawl_ops.remove_collection_from_all_crawls(coll_id)
result = await self.collections.delete_one({"_id": coll_id, "oid": org.id})
if result.deleted_count < 1:
raise HTTPException(status_code=404, detail="collection_not_found")
return {"success": True}
async def download_collection(self, coll_id: UUID, org: Organization):
"""Download all WACZs in collection as streaming nested WACZ"""
coll = await self.get_collection(coll_id, org, resources=True)
resp = await self.storage_ops.download_streaming_wacz(org, coll.resources)
headers = {"Content-Disposition": f'attachment; filename="{coll.name}.wacz"'}
return StreamingResponse(
resp, headers=headers, media_type="application/wacz+zip"
)
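# storage_ops.download_streaming_wacz streams a nested WACZ built from the
# collection's pre-signed resources; the Content-Disposition header names the
# download after the collection.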
async def update_collection_counts_and_tags(self, collection_id: UUID):
"""Set current crawl info in config when crawl begins"""
crawl_count = 0
page_count = 0
total_size = 0
tags = []
async for crawl in self.crawls.find({"collectionIds": collection_id}):
if crawl["state"] not in SUCCESSFUL_STATES:
continue
crawl_count += 1
files = crawl.get("files", [])
for file in files:
total_size += file.get("size", 0)
if crawl.get("stats"):
page_count += crawl.get("stats", {}).get("done", 0)
if crawl.get("tags"):
tags.extend(crawl.get("tags"))
sorted_tags = [tag for tag, count in Counter(tags).most_common()]
await self.collections.find_one_and_update(
{"_id": collection_id},
{
"$set": {
"crawlCount": crawl_count,
"pageCount": page_count,
"totalSize": total_size,
"tags": sorted_tags,
}
},
)
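# Counts only consider crawls in SUCCESSFUL_STATES; page totals come from each
# crawl's stats.done, and tags are stored most-common-first via Counter.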
async def update_crawl_collections(self, crawl_id: str):
"""Update counts and tags for all collections in crawl"""
crawl = await self.crawls.find_one({"_id": crawl_id})
crawl_coll_ids = crawl.get("collectionIds") or []  # tolerate crawls with no collections
for collection_id in crawl_coll_ids:
await self.update_collection_counts_and_tags(collection_id)
async def add_successful_crawl_to_collections(self, crawl_id: str, cid: UUID):
"""Add successful crawl to its auto-add collections."""
workflow = await self.crawl_configs.find_one({"_id": cid})
auto_add_collections = workflow.get("autoAddCollections")
if auto_add_collections:
await self.crawls.find_one_and_update(
{"_id": crawl_id},
{"$set": {"collectionIds": auto_add_collections}},
)
await self.update_crawl_collections(crawl_id)
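# autoAddCollections lives on the workflow (crawl config): when a crawl from
# that workflow succeeds, its collectionIds are set to those collections and
# their counts/tags are refreshed.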
# ============================================================================
# pylint: disable=too-many-locals
def init_collections_api(app, mdb, orgs, storage_ops, event_webhook_ops):
"""init collections api"""
# pylint: disable=invalid-name, unused-argument, too-many-arguments
colls = CollectionOps(mdb, storage_ops, orgs, event_webhook_ops)
org_crawl_dep = orgs.org_crawl_dep
org_viewer_dep = orgs.org_viewer_dep
org_public = orgs.org_public
@app.post("/orgs/{oid}/collections", tags=["collections"])
async def add_collection(
new_coll: CollIn, org: Organization = Depends(org_crawl_dep)
):
return await colls.add_collection(org.id, new_coll)
@app.get(
"/orgs/{oid}/collections",
tags=["collections"],
response_model=PaginatedResponse,
)
async def list_collection_all(
org: Organization = Depends(org_viewer_dep),
pageSize: int = DEFAULT_PAGE_SIZE,
page: int = 1,
sortBy: Optional[str] = None,
sortDirection: int = 1,
name: Optional[str] = None,
namePrefix: Optional[str] = None,
):
collections, total = await colls.list_collections(
org.id,
page_size=pageSize,
page=page,
sort_by=sortBy,
sort_direction=sortDirection,
name=name,
name_prefix=namePrefix,
)
return paginated_format(collections, total, page, pageSize)
@app.get(
"/orgs/{oid}/collections/$all",
tags=["collections"],
)
async def get_collection_all(org: Organization = Depends(org_viewer_dep)):
results = {}
try:
all_collections, _ = await colls.list_collections(org.id, page_size=10_000)
for collection in all_collections:
results[collection.name] = await colls.get_collection_crawl_resources(
collection.id, org
)
except Exception as exc:
# pylint: disable=raise-missing-from
raise HTTPException(
status_code=400, detail="Error Listing All Crawled Files: " + str(exc)
)
return results
@app.get("/orgs/{oid}/collections/search-values", tags=["collections"])
async def get_collection_search_values(
org: Organization = Depends(org_viewer_dep),
):
return await colls.get_collection_search_values(org)
@app.get(
"/orgs/{oid}/collections/{coll_id}",
tags=["collections"],
response_model=CollOut,
)
async def get_collection(
coll_id: UUID, org: Organization = Depends(org_viewer_dep)
) -> CollOut:
return await colls.get_collection(coll_id, org)
@app.get("/orgs/{oid}/collections/{coll_id}/replay.json", tags=["collections"])
async def get_collection_replay(
coll_id: UUID, org: Organization = Depends(org_viewer_dep)
) -> CollOut:
return await colls.get_collection(coll_id, org, resources=True)
@app.get(
"/orgs/{oid}/collections/{coll_id}/public/replay.json", tags=["collections"]
)
async def get_collection_public_replay(
response: Response,
coll_id: UUID,
org: Organization = Depends(org_public),
) -> CollOut:
coll = await colls.get_collection(
coll_id, org, resources=True, public_only=True
)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
return coll
@app.options(
"/orgs/{oid}/collections/{coll_id}/public/replay.json", tags=["collections"]
)
async def get_replay_preflight(response: Response):
response.headers["Access-Control-Allow-Methods"] = "GET, HEAD, OPTIONS"
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
return {}
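# The explicit OPTIONS handler answers CORS preflight requests so the public
# replay.json can be fetched from other origins (for example, an embedded
# replay viewer on another site); the GET handler above sets the matching
# CORS headers on its response.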
@app.patch("/orgs/{oid}/collections/{coll_id}", tags=["collections"])
async def update_collection(
coll_id: UUID,
update: UpdateColl,
org: Organization = Depends(org_crawl_dep),
):
return await colls.update_collection(coll_id, org, update)
@app.post(
"/orgs/{oid}/collections/{coll_id}/add",
tags=["collections"],
response_model=CollOut,
)
async def add_crawl_to_collection(
crawlList: AddRemoveCrawlList,
coll_id: UUID,
org: Organization = Depends(org_crawl_dep),
) -> CollOut:
return await colls.add_crawls_to_collection(coll_id, crawlList.crawlIds, org)
@app.post(
"/orgs/{oid}/collections/{coll_id}/remove",
tags=["collections"],
response_model=CollOut,
)
async def remove_crawl_from_collection(
crawlList: AddRemoveCrawlList,
coll_id: UUID,
org: Organization = Depends(org_crawl_dep),
) -> CollOut:
return await colls.remove_crawls_from_collection(
coll_id, crawlList.crawlIds, org
)
@app.delete(
"/orgs/{oid}/collections/{coll_id}",
tags=["collections"],
)
async def delete_collection(
coll_id: UUID, org: Organization = Depends(org_crawl_dep)
):
return await colls.delete_collection(coll_id, org)
@app.get("/orgs/{oid}/collections/{coll_id}/download", tags=["collections"])
async def download_collection(
coll_id: UUID, org: Organization = Depends(org_viewer_dep)
):
return await colls.download_collection(coll_id, org)
return colls
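# Rough wiring sketch (not from this file; the call site and names in main.py
# are assumptions):
#   coll_ops = init_collections_api(app, mdb, org_ops, storage_ops, event_webhook_ops)
#   coll_ops.set_crawl_ops(crawl_ops)
# i.e. the app factory creates CollectionOps via this function, then injects
# CrawlOps afterwards to complete the deferred wiring described above.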