This PR adds more type safety to the backend codebase:

- All ops class calls should now be type checked
- Circular references are avoided with a `TYPE_CHECKING` conditional (see the sketch below)
- Consistent UUID usage: `uuid.UUID` / `UUID4` replaced with just `UUID`
- Crawl states moved to models and made into lists
- Additional typing added as needed; fixed a few type-related errors
- `CrawlOps` / `UploadOps` / `BaseCrawlOps` now all use the same param init order to simplify changes
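For reference, a minimal sketch of the circular-import pattern this PR standardizes on (the `SomeOps` class below is illustrative; the same structure appears in the collections module later in this PR): the ops dependencies are imported only for the type checker, replaced with a placeholder at runtime, and the injected instance is `cast` to the annotated type.

```python
from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    # real import is only seen by mypy/pyright, so no circular import at runtime
    from .crawls import CrawlOps
else:
    CrawlOps = object  # runtime placeholder


class SomeOps:
    """illustrative ops class using the same dependency-injection pattern"""

    crawl_ops: CrawlOps

    def __init__(self):
        # populated later via set_crawl_ops(), once the real CrawlOps exists
        self.crawl_ops = cast(CrawlOps, None)

    def set_crawl_ops(self, ops):
        """set crawl ops"""
        self.crawl_ops = ops
```

With this arrangement, calls such as `self.crawl_ops.add_to_collection(...)` are checked against the real `CrawlOps` signature, while imports stay one-directional at runtime.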
Collections API module (532 lines, Python):
"""
|
|
Collections API
|
|
"""
|
|
from collections import Counter
|
|
from datetime import datetime
|
|
from uuid import UUID, uuid4
|
|
from typing import Optional, List, TYPE_CHECKING, cast
|
|
|
|
import asyncio
|
|
import pymongo
|
|
from fastapi import Depends, HTTPException, Response
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from .pagination import DEFAULT_PAGE_SIZE, paginated_format
|
|
from .models import (
|
|
Collection,
|
|
CollIn,
|
|
CollOut,
|
|
CollIdName,
|
|
UpdateColl,
|
|
AddRemoveCrawlList,
|
|
CrawlOutWithResources,
|
|
Organization,
|
|
PaginatedResponse,
|
|
SUCCESSFUL_STATES,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from .orgs import OrgOps
|
|
from .storages import StorageOps
|
|
from .webhooks import EventWebhookOps
|
|
from .crawls import CrawlOps
|
|
else:
|
|
OrgOps = StorageOps = EventWebhookOps = CrawlOps = object
|
|
|
|
|
|
# ============================================================================
|
|
class CollectionOps:
|
|
"""ops for working with named collections of crawls"""
|
|
|
|
# pylint: disable=too-many-arguments
|
|
|
|
orgs: OrgOps
|
|
storage_ops: StorageOps
|
|
event_webhook_ops: EventWebhookOps
|
|
crawl_ops: CrawlOps
|
|
|
|
def __init__(self, mdb, storage_ops, orgs, event_webhook_ops):
|
|
self.collections = mdb["collections"]
|
|
self.crawls = mdb["crawls"]
|
|
self.crawl_configs = mdb["crawl_configs"]
|
|
self.crawl_ops = cast(CrawlOps, None)
|
|
|
|
self.orgs = orgs
|
|
self.storage_ops = storage_ops
|
|
self.event_webhook_ops = event_webhook_ops
|
|
|
|
def set_crawl_ops(self, ops):
|
|
"""set crawl ops"""
|
|
self.crawl_ops = ops
|
|
|
|
async def init_index(self):
|
|
"""init lookup index"""
|
|
await self.collections.create_index(
|
|
[("oid", pymongo.ASCENDING), ("name", pymongo.ASCENDING)], unique=True
|
|
)
|
|
|
|
await self.collections.create_index(
|
|
[("oid", pymongo.ASCENDING), ("description", pymongo.ASCENDING)]
|
|
)
|
|
|
|
async def add_collection(self, oid: UUID, coll_in: CollIn):
|
|
"""Add new collection"""
|
|
crawl_ids = coll_in.crawlIds if coll_in.crawlIds else []
|
|
coll_id = uuid4()
|
|
modified = datetime.utcnow().replace(microsecond=0, tzinfo=None)
|
|
|
|
coll = Collection(
|
|
id=coll_id,
|
|
oid=oid,
|
|
name=coll_in.name,
|
|
description=coll_in.description,
|
|
modified=modified,
|
|
isPublic=coll_in.isPublic,
|
|
)
|
|
try:
|
|
await self.collections.insert_one(coll.to_dict())
|
|
org = await self.orgs.get_org_by_id(oid)
|
|
if crawl_ids:
|
|
await self.crawl_ops.add_to_collection(crawl_ids, coll_id, org)
|
|
await self.update_collection_counts_and_tags(coll_id)
|
|
asyncio.create_task(
|
|
self.event_webhook_ops.create_added_to_collection_notification(
|
|
crawl_ids, coll_id, org
|
|
)
|
|
)
|
|
|
|
return {"added": True, "id": coll_id, "name": coll.name}
|
|
except pymongo.errors.DuplicateKeyError:
|
|
# pylint: disable=raise-missing-from
|
|
raise HTTPException(status_code=400, detail="collection_name_taken")
|
|
|
|
    async def update_collection(
        self, coll_id: UUID, org: Organization, update: UpdateColl
    ):
        """Update collection"""
        query = update.dict(exclude_unset=True)

        if len(query) == 0:
            raise HTTPException(status_code=400, detail="no_update_data")

        query["modified"] = datetime.utcnow().replace(microsecond=0, tzinfo=None)

        try:
            result = await self.collections.find_one_and_update(
                {"_id": coll_id, "oid": org.id},
                {"$set": query},
                return_document=pymongo.ReturnDocument.AFTER,
            )
        except pymongo.errors.DuplicateKeyError:
            # pylint: disable=raise-missing-from
            raise HTTPException(status_code=400, detail="collection_name_taken")

        if not result:
            raise HTTPException(status_code=404, detail="collection_not_found")

        return {"updated": True}

    async def add_crawls_to_collection(
        self, coll_id: UUID, crawl_ids: List[str], org: Organization
    ) -> CollOut:
        """Add crawls to collection"""
        await self.crawl_ops.add_to_collection(crawl_ids, coll_id, org)

        modified = datetime.utcnow().replace(microsecond=0, tzinfo=None)
        result = await self.collections.find_one_and_update(
            {"_id": coll_id},
            {"$set": {"modified": modified}},
            return_document=pymongo.ReturnDocument.AFTER,
        )
        if not result:
            raise HTTPException(status_code=404, detail="collection_not_found")

        await self.update_collection_counts_and_tags(coll_id)

        asyncio.create_task(
            self.event_webhook_ops.create_added_to_collection_notification(
                crawl_ids, coll_id, org
            )
        )

        return await self.get_collection(coll_id, org)

    async def remove_crawls_from_collection(
        self, coll_id: UUID, crawl_ids: List[str], org: Organization
    ) -> CollOut:
        """Remove crawls from collection"""
        await self.crawl_ops.remove_from_collection(crawl_ids, coll_id)
        modified = datetime.utcnow().replace(microsecond=0, tzinfo=None)
        result = await self.collections.find_one_and_update(
            {"_id": coll_id},
            {"$set": {"modified": modified}},
            return_document=pymongo.ReturnDocument.AFTER,
        )
        if not result:
            raise HTTPException(status_code=404, detail="collection_not_found")

        await self.update_collection_counts_and_tags(coll_id)

        asyncio.create_task(
            self.event_webhook_ops.create_removed_from_collection_notification(
                crawl_ids, coll_id, org
            )
        )

        return await self.get_collection(coll_id, org)

    async def get_collection(
        self, coll_id: UUID, org: Organization, resources=False, public_only=False
    ) -> CollOut:
        """Get collection by id"""
        query: dict[str, object] = {"_id": coll_id}
        if public_only:
            query["isPublic"] = True

        result = await self.collections.find_one(query)
        if not result:
            raise HTTPException(status_code=404, detail="collection_not_found")

        if resources:
            result["resources"] = await self.get_collection_crawl_resources(
                coll_id, org
            )
        return CollOut.from_dict(result)

    async def list_collections(
        self,
        oid: UUID,
        page_size: int = DEFAULT_PAGE_SIZE,
        page: int = 1,
        sort_by: Optional[str] = None,
        sort_direction: int = 1,
        name: Optional[str] = None,
        name_prefix: Optional[str] = None,
    ):
        """List all collections for org"""
        # pylint: disable=too-many-locals, duplicate-code
        # Zero-index page for query
        page = page - 1
        skip = page * page_size

        match_query: dict[str, object] = {"oid": oid}

        if name:
            match_query["name"] = name

        elif name_prefix:
            regex_pattern = f"^{name_prefix}"
            match_query["name"] = {"$regex": regex_pattern, "$options": "i"}

        aggregate = [{"$match": match_query}]

        if sort_by:
            if sort_by not in ("modified", "name", "description", "totalSize"):
                raise HTTPException(status_code=400, detail="invalid_sort_by")
            if sort_direction not in (1, -1):
                raise HTTPException(status_code=400, detail="invalid_sort_direction")

            aggregate.extend([{"$sort": {sort_by: sort_direction}}])

        aggregate.extend(
            [
                {
                    "$facet": {
                        "items": [
                            {"$skip": skip},
                            {"$limit": page_size},
                        ],
                        "total": [{"$count": "count"}],
                    }
                },
            ]
        )

        cursor = self.collections.aggregate(
            aggregate, collation=pymongo.collation.Collation(locale="en")
        )
        results = await cursor.to_list(length=1)
        result = results[0]
        items = result["items"]

        try:
            total = int(result["total"][0]["count"])
        except (IndexError, ValueError):
            total = 0

        collections = [CollOut.from_dict(res) for res in items]

        return collections, total

    async def get_collection_crawl_resources(self, coll_id: UUID, org: Organization):
        """Return pre-signed resources for all collection crawl files."""
        coll = await self.get_collection(coll_id, org)
        if not coll:
            raise HTTPException(status_code=404, detail="collection_not_found")

        all_files = []

        crawls, _ = await self.crawl_ops.list_all_base_crawls(
            collection_id=coll_id,
            states=list(SUCCESSFUL_STATES),
            page_size=10_000,
            cls_type=CrawlOutWithResources,
        )

        for crawl in crawls:
            if crawl.resources:
                all_files.extend(crawl.resources)

        return all_files

    async def get_collection_names(self, uuids: List[UUID]):
        """return list of {id, name} for the given collection ids"""
        cursor = self.collections.find(
            {"_id": {"$in": uuids}}, projection=["_id", "name"]
        )
        names = await cursor.to_list(length=1000)
        names = [
            CollIdName(id=namedata["_id"], name=namedata["name"]) for namedata in names
        ]
        return names

    async def get_collection_search_values(self, org: Organization):
        """Return list of collection names"""
        names = await self.collections.distinct("name", {"oid": org.id})
        # Remove empty strings
        names = [name for name in names if name]
        return {"names": names}

    async def delete_collection(self, coll_id: UUID, org: Organization):
        """Delete collection and remove from associated crawls."""
        await self.crawl_ops.remove_collection_from_all_crawls(coll_id)

        result = await self.collections.delete_one({"_id": coll_id, "oid": org.id})
        if result.deleted_count < 1:
            raise HTTPException(status_code=404, detail="collection_not_found")

        return {"success": True}

    async def download_collection(self, coll_id: UUID, org: Organization):
        """Download all WACZs in collection as streaming nested WACZ"""
        coll = await self.get_collection(coll_id, org, resources=True)

        resp = await self.storage_ops.download_streaming_wacz(org, coll.resources)

        headers = {"Content-Disposition": f'attachment; filename="{coll.name}.wacz"'}
        return StreamingResponse(
            resp, headers=headers, media_type="application/wacz+zip"
        )

    async def update_collection_counts_and_tags(self, collection_id: UUID):
        """Recalculate crawl count, page count, total size, and tags for collection"""
        crawl_count = 0
        page_count = 0
        total_size = 0
        tags = []

        async for crawl in self.crawls.find({"collectionIds": collection_id}):
            if crawl["state"] not in SUCCESSFUL_STATES:
                continue
            crawl_count += 1
            files = crawl.get("files", [])
            for file in files:
                total_size += file.get("size", 0)
            if crawl.get("stats"):
                page_count += crawl.get("stats", {}).get("done", 0)
            if crawl.get("tags"):
                tags.extend(crawl.get("tags"))

        sorted_tags = [tag for tag, count in Counter(tags).most_common()]

        await self.collections.find_one_and_update(
            {"_id": collection_id},
            {
                "$set": {
                    "crawlCount": crawl_count,
                    "pageCount": page_count,
                    "totalSize": total_size,
                    "tags": sorted_tags,
                }
            },
        )

    async def update_crawl_collections(self, crawl_id: str):
        """Update counts and tags for all collections in crawl"""
        crawl = await self.crawls.find_one({"_id": crawl_id})
        crawl_coll_ids = crawl.get("collectionIds")
        for collection_id in crawl_coll_ids:
            await self.update_collection_counts_and_tags(collection_id)

    async def add_successful_crawl_to_collections(self, crawl_id: str, cid: UUID):
        """Add successful crawl to its auto-add collections."""
        workflow = await self.crawl_configs.find_one({"_id": cid})
        auto_add_collections = workflow.get("autoAddCollections")
        if auto_add_collections:
            await self.crawls.find_one_and_update(
                {"_id": crawl_id},
                {"$set": {"collectionIds": auto_add_collections}},
            )
            await self.update_crawl_collections(crawl_id)

# ============================================================================
# pylint: disable=too-many-locals
def init_collections_api(app, mdb, orgs, storage_ops, event_webhook_ops):
    """init collections api"""
    # pylint: disable=invalid-name, unused-argument, too-many-arguments

    colls = CollectionOps(mdb, storage_ops, orgs, event_webhook_ops)

    org_crawl_dep = orgs.org_crawl_dep
    org_viewer_dep = orgs.org_viewer_dep
    org_public = orgs.org_public

    @app.post("/orgs/{oid}/collections", tags=["collections"])
    async def add_collection(
        new_coll: CollIn, org: Organization = Depends(org_crawl_dep)
    ):
        return await colls.add_collection(org.id, new_coll)

    @app.get(
        "/orgs/{oid}/collections",
        tags=["collections"],
        response_model=PaginatedResponse,
    )
    async def list_collection_all(
        org: Organization = Depends(org_viewer_dep),
        pageSize: int = DEFAULT_PAGE_SIZE,
        page: int = 1,
        sortBy: Optional[str] = None,
        sortDirection: int = 1,
        name: Optional[str] = None,
        namePrefix: Optional[str] = None,
    ):
        collections, total = await colls.list_collections(
            org.id,
            page_size=pageSize,
            page=page,
            sort_by=sortBy,
            sort_direction=sortDirection,
            name=name,
            name_prefix=namePrefix,
        )
        return paginated_format(collections, total, page, pageSize)

    @app.get(
        "/orgs/{oid}/collections/$all",
        tags=["collections"],
    )
    async def get_collection_all(org: Organization = Depends(org_viewer_dep)):
        results = {}
        try:
            all_collections, _ = await colls.list_collections(org.id, page_size=10_000)
            for collection in all_collections:
                results[collection.name] = await colls.get_collection_crawl_resources(
                    collection.id, org
                )
        except Exception as exc:
            # pylint: disable=raise-missing-from
            raise HTTPException(
                status_code=400, detail="Error Listing All Crawled Files: " + str(exc)
            )

        return results

    @app.get("/orgs/{oid}/collections/search-values", tags=["collections"])
    async def get_collection_search_values(
        org: Organization = Depends(org_viewer_dep),
    ):
        return await colls.get_collection_search_values(org)

    @app.get(
        "/orgs/{oid}/collections/{coll_id}",
        tags=["collections"],
        response_model=CollOut,
    )
    async def get_collection(
        coll_id: UUID, org: Organization = Depends(org_viewer_dep)
    ) -> CollOut:
        return await colls.get_collection(coll_id, org)

    @app.get("/orgs/{oid}/collections/{coll_id}/replay.json", tags=["collections"])
    async def get_collection_replay(
        coll_id: UUID, org: Organization = Depends(org_viewer_dep)
    ) -> CollOut:
        return await colls.get_collection(coll_id, org, resources=True)

    @app.get(
        "/orgs/{oid}/collections/{coll_id}/public/replay.json", tags=["collections"]
    )
    async def get_collection_public_replay(
        response: Response,
        coll_id: UUID,
        org: Organization = Depends(org_public),
    ) -> CollOut:
        coll = await colls.get_collection(
            coll_id, org, resources=True, public_only=True
        )
        response.headers["Access-Control-Allow-Origin"] = "*"
        response.headers["Access-Control-Allow-Headers"] = "*"
        return coll

    @app.options(
        "/orgs/{oid}/collections/{coll_id}/public/replay.json", tags=["collections"]
    )
    async def get_replay_preflight(response: Response):
        response.headers["Access-Control-Allow-Methods"] = "GET, HEAD, OPTIONS"
        response.headers["Access-Control-Allow-Origin"] = "*"
        response.headers["Access-Control-Allow-Headers"] = "*"
        return {}

    @app.patch("/orgs/{oid}/collections/{coll_id}", tags=["collections"])
    async def update_collection(
        coll_id: UUID,
        update: UpdateColl,
        org: Organization = Depends(org_crawl_dep),
    ):
        return await colls.update_collection(coll_id, org, update)

    @app.post(
        "/orgs/{oid}/collections/{coll_id}/add",
        tags=["collections"],
        response_model=CollOut,
    )
    async def add_crawl_to_collection(
        crawlList: AddRemoveCrawlList,
        coll_id: UUID,
        org: Organization = Depends(org_crawl_dep),
    ) -> CollOut:
        return await colls.add_crawls_to_collection(coll_id, crawlList.crawlIds, org)

    @app.post(
        "/orgs/{oid}/collections/{coll_id}/remove",
        tags=["collections"],
        response_model=CollOut,
    )
    async def remove_crawl_from_collection(
        crawlList: AddRemoveCrawlList,
        coll_id: UUID,
        org: Organization = Depends(org_crawl_dep),
    ) -> CollOut:
        return await colls.remove_crawls_from_collection(
            coll_id, crawlList.crawlIds, org
        )

    @app.delete(
        "/orgs/{oid}/collections/{coll_id}",
        tags=["collections"],
    )
    async def delete_collection(
        coll_id: UUID, org: Organization = Depends(org_crawl_dep)
    ):
        return await colls.delete_collection(coll_id, org)

    @app.get("/orgs/{oid}/collections/{coll_id}/download", tags=["collections"])
    async def download_collection(
        coll_id: UUID, org: Organization = Depends(org_viewer_dep)
    ):
        return await colls.download_collection(coll_id, org)

    return colls