# Also ensure name is not empty by adding minimum length of 1
# Co-authored-by: Ilya Kreymer <ikreymer@gmail.com>
"""
|
|
Collections API
|
|
"""
|
|
import uuid
|
|
from typing import Optional, List
|
|
|
|
import pymongo
|
|
from fastapi import Depends, HTTPException
|
|
|
|
from pydantic import BaseModel, UUID4, Field
|
|
|
|
from .db import BaseMongoModel
|
|
from .orgs import Organization
|
|
from .pagination import DEFAULT_PAGE_SIZE, paginated_format
|
|
|
|
|
|
# ============================================================================
|
|
class Collection(BaseMongoModel):
    """Org collection structure"""

    # min_length=1 ensures the name is never an empty string
    name: str = Field(..., min_length=1)

    # id of the org that owns this collection
    oid: UUID4

    # ids of crawls belonging to this collection
    crawlIds: Optional[List[str]] = []

    # optional free-text description
    description: Optional[str]
|
|
|
|
|
|
# ============================================================================
|
|
class CollIn(BaseModel):
    """Collection Passed in By User"""

    # min_length=1 rejects empty names at the API boundary,
    # matching the constraint on the stored Collection model
    name: str = Field(..., min_length=1)
    description: Optional[str]
    crawlIds: Optional[List[str]] = []
|
|
|
|
|
|
# ============================================================================
|
|
class UpdateColl(BaseModel):
    """Update collection"""

    # Fix: enforce min_length=1 here too, consistent with Collection and
    # CollIn — the name may be omitted from an update, but must never be
    # set to an empty string
    name: Optional[str] = Field(None, min_length=1)
    crawlIds: Optional[List[str]] = []
    description: Optional[str]
|
|
|
|
|
|
# ============================================================================
|
|
class CollectionOps:
    """Ops for working with named collections of crawls.

    Backed by the ``collections`` Mongo collection; raises HTTPException
    (400/404) for user-facing errors.
    """

    def __init__(self, mdb, crawls, crawl_manager, orgs):
        self.collections = mdb["collections"]

        self.crawls = crawls
        self.crawl_manager = crawl_manager
        self.orgs = orgs

    async def init_index(self):
        """init lookup index"""
        # Unique (oid, name) index: a collection name may repeat across
        # orgs but must be unique within one org
        await self.collections.create_index(
            [("oid", pymongo.ASCENDING), ("name", pymongo.ASCENDING)], unique=True
        )

    async def add_collection(
        self,
        oid: uuid.UUID,
        name: str,
        crawl_ids: Optional[List[str]],
        description: str = None,
    ):
        """Add new collection, returning {"added": {"id", "name"}}.

        Raises HTTPException(400) if the name is already taken in the org.
        """
        crawl_ids = crawl_ids if crawl_ids else []
        coll_id = uuid.uuid4()
        coll = Collection(
            id=coll_id,
            oid=oid,
            name=name,
            crawlIds=crawl_ids,
            description=description,
        )
        try:
            await self.collections.insert_one(coll.to_dict())
            return {"added": {"id": coll_id, "name": name}}
        except pymongo.errors.DuplicateKeyError:
            # unique (oid, name) index violated -> name already in use
            # pylint: disable=raise-missing-from
            raise HTTPException(status_code=400, detail="collection_name_taken")

    async def update_collection(self, coll_id: uuid.UUID, update: UpdateColl):
        """Update collection, returning the updated Collection.

        Raises 400 on empty update or name collision, 404 if not found.
        """
        # only apply fields the caller explicitly set
        query = update.dict(exclude_unset=True)

        if len(query) == 0:
            raise HTTPException(status_code=400, detail="no_update_data")

        try:
            result = await self.collections.find_one_and_update(
                {"_id": coll_id},
                {"$set": query},
                return_document=pymongo.ReturnDocument.AFTER,
            )
        except pymongo.errors.DuplicateKeyError:
            # rename collides with an existing (oid, name) pair
            # pylint: disable=raise-missing-from
            raise HTTPException(status_code=400, detail="collection_name_taken")

        if not result:
            raise HTTPException(status_code=404, detail="collection_not_found")

        return Collection.from_dict(result)

    async def add_crawl_to_collection(self, coll_id: uuid.UUID, crawl_id: str):
        """Add crawl to collection; returns updated Collection or raises 404."""
        result = await self.collections.find_one_and_update(
            {"_id": coll_id},
            {"$push": {"crawlIds": crawl_id}},
            return_document=pymongo.ReturnDocument.AFTER,
        )
        if not result:
            raise HTTPException(status_code=404, detail="collection_not_found")

        return Collection.from_dict(result)

    async def remove_crawl_from_collection(self, coll_id: uuid.UUID, crawl_id: str):
        """Remove crawl from collection; returns updated Collection or raises 404."""
        result = await self.collections.find_one_and_update(
            {"_id": coll_id},
            {"$pull": {"crawlIds": crawl_id}},
            return_document=pymongo.ReturnDocument.AFTER,
        )
        if not result:
            raise HTTPException(status_code=404, detail="collection_not_found")

        return Collection.from_dict(result)

    async def get_collection(self, coll_id: uuid.UUID):
        """Get collection by id, or None if it doesn't exist"""
        res = await self.collections.find_one({"_id": coll_id})
        return Collection.from_dict(res) if res else None

    async def find_collections(self, oid: uuid.UUID, names: List[str]):
        """Find all collections for org given a list of names.

        Returns the found names; raises HTTPException(400) listing any
        names that don't exist in the org.
        """
        cursor = self.collections.find(
            {"oid": oid, "name": {"$in": names}}, projection=["_id", "name"]
        )
        results = await cursor.to_list(length=1000)
        if len(results) != len(names):
            # Fix: compute missing names without mutating the caller's
            # list (previous code removed found names from `names` in place)
            found = {result["name"] for result in results}
            missing = [name for name in names if name not in found]
            if missing:
                raise HTTPException(
                    status_code=400,
                    detail=f"Specified collection(s) not found: {', '.join(missing)}",
                )

        return [result["name"] for result in results]

    async def list_collections(
        self, oid: uuid.UUID, page_size: int = DEFAULT_PAGE_SIZE, page: int = 1
    ):
        """List one page of collections for org; returns (collections, total)."""
        # Zero-index page for query
        page = page - 1
        skip = page * page_size

        match_query = {"oid": oid}

        total = await self.collections.count_documents(match_query)

        cursor = self.collections.find(match_query, skip=skip, limit=page_size)
        results = await cursor.to_list(length=page_size)
        collections = [Collection.from_dict(res) for res in results]

        return collections, total

    async def get_collection_crawls(self, coll_id: uuid.UUID, oid: uuid.UUID):
        """Find collection and get all crawl resources.

        Returns {"resources": [...]}; raises 404 if the collection is missing.
        """

        coll = await self.get_collection(coll_id)
        if not coll:
            raise HTTPException(status_code=404, detail="collection_not_found")

        # Fix: the org lookup is loop-invariant — fetch once instead of
        # once per crawl id
        org = await self.orgs.get_org_by_id(oid)

        all_files = []

        for crawl_id in coll.crawlIds:
            crawl = await self.crawls.get_crawl(crawl_id, org)
            if not crawl.resources:
                continue

            all_files.extend(crawl.resources)

        return {"resources": all_files}
|
|
|
|
|
|
# ============================================================================
|
|
# pylint: disable=too-many-locals
|
|
def init_collections_api(app, mdb, crawls, orgs, crawl_manager):
    """init collections api

    Registers the /orgs/{oid}/collections endpoints on `app` and
    returns the CollectionOps instance.
    """
    # pylint: disable=invalid-name, unused-argument

    colls = CollectionOps(mdb, crawls, crawl_manager, orgs)

    # write access required for create/update/add/remove; read for listing
    org_crawl_dep = orgs.org_crawl_dep
    org_viewer_dep = orgs.org_viewer_dep

    @app.post("/orgs/{oid}/collections", tags=["collections"])
    async def add_collection(
        new_coll: CollIn, org: Organization = Depends(org_crawl_dep)
    ):
        return await colls.add_collection(
            org.id, new_coll.name, new_coll.crawlIds, new_coll.description
        )

    @app.get(
        "/orgs/{oid}/collections",
        tags=["collections"],
    )
    async def list_collection_all(
        org: Organization = Depends(org_viewer_dep),
        pageSize: int = DEFAULT_PAGE_SIZE,
        page: int = 1,
    ):
        collections, total = await colls.list_collections(
            org.id, page_size=pageSize, page=page
        )
        return paginated_format(collections, total, page, pageSize)

    @app.get(
        "/orgs/{oid}/collections/$all",
        tags=["collections"],
    )
    async def get_collection_all(org: Organization = Depends(org_viewer_dep)):
        results = {}
        try:
            # Fix: list_collections is a coroutine and must be awaited —
            # without `await` the tuple unpacking fails at runtime
            all_collections, _ = await colls.list_collections(
                org.id, page_size=10_000
            )
            for collection in all_collections:
                # Fix: argument order is (coll_id, oid), and coll_id must
                # stay a UUID — previously org.id and str(collection.id)
                # were passed swapped
                results[collection.name] = await colls.get_collection_crawls(
                    collection.id, org.id
                )
        # pylint: disable=broad-except
        except Exception as exc:
            # pylint: disable=raise-missing-from
            raise HTTPException(
                status_code=400, detail="Error Listing All Crawled Files: " + str(exc)
            )

        return results

    @app.get(
        "/orgs/{oid}/collections/{coll_id}",
        tags=["collections"],
    )
    async def get_collection_crawls(
        coll_id: uuid.UUID, org: Organization = Depends(org_viewer_dep)
    ):
        try:
            results = await colls.get_collection_crawls(coll_id, org.id)

        # pylint: disable=broad-except
        except Exception as exc:
            # pylint: disable=raise-missing-from
            raise HTTPException(
                status_code=400, detail=f"Error Listing Collection: {exc}"
            )

        return results

    @app.post(
        "/orgs/{oid}/collections/{coll_id}/update",
        tags=["collections"],
        response_model=Collection,
    )
    async def update_collection(
        coll_id: uuid.UUID,
        update: UpdateColl,
        org: Organization = Depends(org_crawl_dep),
    ):
        return await colls.update_collection(coll_id, update)

    @app.get(
        "/orgs/{oid}/collections/{coll_id}/add",
        tags=["collections"],
        response_model=Collection,
    )
    async def add_crawl_to_collection(
        crawlId: str, coll_id: uuid.UUID, org: Organization = Depends(org_crawl_dep)
    ):
        return await colls.add_crawl_to_collection(coll_id, crawlId)

    @app.get(
        "/orgs/{oid}/collections/{coll_id}/remove",
        tags=["collections"],
        response_model=Collection,
    )
    async def remove_crawl_from_collection(
        crawlId: str, coll_id: uuid.UUID, org: Organization = Depends(org_crawl_dep)
    ):
        return await colls.remove_crawl_from_collection(coll_id, crawlId)

    return colls
|