| """ base crawl type """
 | |
| 
 | |
| import asyncio
 | |
| import uuid
 | |
| import os
 | |
| from datetime import datetime, timedelta
 | |
| from typing import Optional, Dict, List
 | |
| 
 | |
| from pydantic import BaseModel, UUID4
 | |
| from fastapi import HTTPException, Depends
 | |
| from .db import BaseMongoModel
 | |
| from .orgs import Organization
 | |
| from .pagination import PaginatedResponseModel, paginated_format, DEFAULT_PAGE_SIZE
 | |
| from .storages import get_presigned_url, delete_crawl_file_object
 | |
| from .users import User
 | |
| from .utils import dt_now
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
| class CrawlFile(BaseModel):
 | |
|     """file from a crawl"""
 | |
| 
 | |
|     filename: str
 | |
|     hash: str
 | |
|     size: int
 | |
|     def_storage_name: Optional[str]
 | |
| 
 | |
|     presignedUrl: Optional[str]
 | |
|     expireAt: Optional[datetime]
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
| class CrawlFileOut(BaseModel):
 | |
|     """output for file from a crawl (conformance to Data Resource Spec)"""
 | |
| 
 | |
|     name: str
 | |
|     path: str
 | |
|     hash: str
 | |
|     size: int
 | |
|     crawlId: Optional[str]
 | |
| 
 | |
| 
 | |
| # ============================================================================
 | |
| class BaseCrawl(BaseMongoModel):
 | |
|     """Base Crawl object (representing crawls, uploads and manual sessions)"""
 | |
| 
 | |
|     id: str
 | |
| 
 | |
|     userid: UUID4
 | |
|     oid: UUID4
 | |
| 
 | |
|     started: datetime
 | |
|     finished: Optional[datetime]
 | |
| 
 | |
|     state: str
 | |
| 
 | |
|     stats: Optional[Dict[str, int]]
 | |
| 
 | |
|     files: Optional[List[CrawlFile]] = []
 | |
| 
 | |
|     notes: Optional[str]
 | |
| 
 | |
|     errors: Optional[List[str]] = []
 | |
| 
 | |
|     collections: Optional[List[UUID4]] = []
 | |
| 
 | |
|     fileSize: int = 0
 | |
|     fileCount: int = 0
 | |
| 
 | |
| 
 | |
# ============================================================================
class BaseCrawlOut(BaseMongoModel):
    """Base crawl output model"""

    # pylint: disable=duplicate-code

    type: Optional[str]

    id: str

    userid: UUID4
    oid: UUID4

    userName: Optional[str]

    name: Optional[str]
    description: Optional[str]

    started: datetime
    finished: Optional[datetime]

    state: str

    stats: Optional[Dict[str, int]]

    fileSize: int = 0
    fileCount: int = 0

    tags: Optional[List[str]] = []

    notes: Optional[str]

    errors: Optional[List[str]]

    collections: Optional[List[UUID4]] = []


# ============================================================================
class BaseCrawlOutWithResources(BaseCrawlOut):
    """includes resources"""

    resources: Optional[List[CrawlFileOut]] = []


# ============================================================================
class UpdateCrawl(BaseModel):
    """Update crawl"""

    tags: Optional[List[str]] = []
    notes: Optional[str]


# ============================================================================
class DeleteCrawlList(BaseModel):
    """delete crawl list POST body"""

    crawl_ids: List[str]


# ============================================================================
class BaseCrawlOps:
    """operations that apply to all crawls"""

    # pylint: disable=duplicate-code, too-many-arguments, too-many-locals

    def __init__(self, mdb, users, crawl_manager):
        self.crawls = mdb["crawls"]
        self.crawl_manager = crawl_manager
        self.user_manager = users

        # presign lifetime comes from PRESIGN_DURATION_MINUTES (default: 60),
        # converted to seconds for the storage layer
        self.presign_duration_seconds = (
            int(os.environ.get("PRESIGN_DURATION_MINUTES", 60)) * 60
        )

    async def get_crawl_raw(
        self,
        crawlid: str,
        org: Optional[Organization] = None,
        type_: Optional[str] = None,
    ):
        """Get data for single crawl"""

        query = {"_id": crawlid}
        if org:
            query["oid"] = org.id

        if type_:
            query["type"] = type_

        res = await self.crawls.find_one(query)

        if not res:
            raise HTTPException(status_code=404, detail=f"Crawl not found: {crawlid}")

        return res

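    # get_crawl returns the API-facing model: raw file entries are swapped
    # out for presigned "resources", and internal error details are stripped
    # from the response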
    async def get_crawl(
        self,
        crawlid: str,
        org: Optional[Organization] = None,
        type_: Optional[str] = None,
    ):
        """Get data for single base crawl"""

        res = await self.get_crawl_raw(crawlid, org, type_)

        if res.get("files"):
            files = [CrawlFile(**data) for data in res["files"]]

            del res["files"]

            res["resources"] = await self._resolve_signed_urls(files, org, crawlid)

        # use pop() so documents without an errors field don't raise KeyError
        res.pop("errors", None)

        crawl = BaseCrawlOutWithResources.from_dict(res)

        user = await self.user_manager.get(crawl.userid)
        if user:
            # pylint: disable=invalid-name
            crawl.userName = user.name

        return crawl

    async def get_resource_resolved_raw_crawl(
        self, crawlid: str, org: Organization, type_=None
    ):
        """return single base crawl with resources resolved"""
        res = await self.get_crawl_raw(crawlid=crawlid, type_=type_, org=org)
        files = [CrawlFile(**data) for data in res.get("files", [])]
        res["resources"] = await self._resolve_signed_urls(files, org, res["_id"])
        return res

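    # only fields explicitly set in the request are written: exclude_unset
    # drops everything the client did not send, and an empty update is a 400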
    async def update_crawl(
        self, crawl_id: str, org: Organization, update: UpdateCrawl, type_=None
    ):
        """Update existing crawl (tags and notes only for now)"""
        update_values = update.dict(exclude_unset=True)
        if len(update_values) == 0:
            raise HTTPException(status_code=400, detail="no_update_data")

        query = {"_id": crawl_id, "oid": org.id}
        if type_:
            query["type"] = type_

        # update in db
        result = await self.crawls.find_one_and_update(
            query,
            {"$set": update_values},
        )

        if not result:
            raise HTTPException(status_code=404, detail="crawl_not_found")

        return {"updated": True}

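    # delete_crawls removes storage files first, then the db documents, and
    # reports (deleted_count, total bytes freed, set of cid values whose
    # configs may need their stats updated)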
    async def delete_crawls(
        self, org: Organization, delete_list: DeleteCrawlList, type_=None
    ):
        """Delete a list of crawls by id for given org"""
        cids_to_update = set()

        size = 0

        for crawl_id in delete_list.crawl_ids:
            crawl = await self.get_crawl_raw(crawl_id, org)
            size += await self._delete_crawl_files(crawl, org)
            if crawl.get("cid"):
                cids_to_update.add(crawl.get("cid"))

        query = {"_id": {"$in": delete_list.crawl_ids}, "oid": org.id}
        if type_:
            query["type"] = type_

        res = await self.crawls.delete_many(query)

        return res.deleted_count, size, cids_to_update

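    # storage deletes are expected to answer HTTP 204 (No Content); anything
    # else aborts the batch with a 400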
    async def _delete_crawl_files(self, crawl, org: Organization):
        """Delete files associated with crawl from storage."""
        crawl = BaseCrawl.from_dict(crawl)
        size = 0
        for file_ in crawl.files:
            size += file_.size
            status_code = await delete_crawl_file_object(
                org, file_, self.crawl_manager
            )
            if status_code != 204:
                raise HTTPException(status_code=400, detail="file_deletion_error")

        return size

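    # Presigned URLs are cached on the file entries themselves: a URL is only
    # regenerated when none is cached or the cached one has expired, and the
    # refreshed URLs are written back to the db as a fire-and-forget task so
    # the response isn't blocked on the writes.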
    async def _resolve_signed_urls(
        self, files: List[CrawlFile], org: Organization, crawl_id: Optional[str] = None
    ):
        if not files:
            return []

        delta = timedelta(seconds=self.presign_duration_seconds)

        updates = []
        out_files = []

        for file_ in files:
            presigned_url = file_.presignedUrl
            now = dt_now()

            # refresh if no url is cached, no expiry is recorded, or it expired
            if not presigned_url or not file_.expireAt or now >= file_.expireAt:
                exp = now + delta
                presigned_url = await get_presigned_url(
                    org, file_, self.crawl_manager, self.presign_duration_seconds
                )
                updates.append(
                    (
                        {"files.filename": file_.filename},
                        {
                            "$set": {
                                "files.$.presignedUrl": presigned_url,
                                "files.$.expireAt": exp,
                            }
                        },
                    )
                )

            out_files.append(
                CrawlFileOut(
                    name=file_.filename,
                    path=presigned_url,
                    hash=file_.hash,
                    size=file_.size,
                    crawlId=crawl_id,
                )
            )

        if updates:
            asyncio.create_task(self._update_presigned(updates))

        return out_files

    async def _update_presigned(self, updates):
        for update in updates:
            await self.crawls.find_one_and_update(*update)

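    # membership is tracked as a list of collection ids on each crawl
    # document; adding is rejected if the crawl is already in the collection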
    async def add_to_collection(
        self, crawl_ids: List[uuid.UUID], collection_id: uuid.UUID, org: Organization
    ):
        """Add crawls to collection."""
        for crawl_id in crawl_ids:
            crawl_raw = await self.get_crawl_raw(crawl_id, org)
            crawl_collections = crawl_raw.get("collections")
            if crawl_collections and collection_id in crawl_collections:
                raise HTTPException(
                    status_code=400, detail="crawl_already_in_collection"
                )

            await self.crawls.find_one_and_update(
                {"_id": crawl_id},
                {"$push": {"collections": collection_id}},
            )

    async def remove_from_collection(
        self, crawl_ids: List[uuid.UUID], collection_id: uuid.UUID
    ):
        """Remove crawls from collection."""
        for crawl_id in crawl_ids:
            await self.crawls.find_one_and_update(
                {"_id": crawl_id},
                {"$pull": {"collections": collection_id}},
            )

    async def remove_collection_from_all_crawls(self, collection_id: uuid.UUID):
        """Remove collection id from all crawls it's currently in."""
        await self.crawls.update_many(
            {"collections": collection_id},
            {"$pull": {"collections": collection_id}},
        )

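    # Listing runs as a single aggregation: filters become $match stages, the
    # owning user's name is joined in via $lookup, and $facet returns the
    # current page and the total count in one round trip.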
    # pylint: disable=too-many-branches
    async def list_all_base_crawls(
        self,
        org: Optional[Organization] = None,
        userid: Optional[uuid.UUID] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        collection_id: Optional[str] = None,
        states: Optional[List[str]] = None,
        page_size: int = DEFAULT_PAGE_SIZE,
        page: int = 1,
        sort_by: Optional[str] = None,
        sort_direction: int = -1,
        cls_type: type[BaseCrawlOut] = BaseCrawlOut,
        type_=None,
    ):
        """List crawls of all types from the db"""
        # Zero-index page for query
        page = page - 1
        skip = page * page_size

        oid = org.id if org else None

        query = {}
        if type_:
            query["type"] = type_
        if oid:
            query["oid"] = oid

        if userid:
            query["userid"] = userid

        if states:
            query["state"] = {"$in": states}

        aggregate = [{"$match": query}]

        if name:
            aggregate.extend([{"$match": {"name": name}}])

        if description:
            aggregate.extend([{"$match": {"description": description}}])

        if collection_id:
            aggregate.extend([{"$match": {"collections": {"$in": [collection_id]}}}])

        if sort_by:
            if sort_by not in ("started", "finished"):
                raise HTTPException(status_code=400, detail="invalid_sort_by")
            if sort_direction not in (1, -1):
                raise HTTPException(status_code=400, detail="invalid_sort_direction")

            aggregate.extend([{"$sort": {sort_by: sort_direction}}])

        aggregate.extend(
            [
                {
                    "$lookup": {
                        "from": "users",
                        "localField": "userid",
                        "foreignField": "id",
                        "as": "userName",
                    },
                },
                {"$set": {"userName": {"$arrayElemAt": ["$userName.name", 0]}}},
                {
                    "$facet": {
                        "items": [
                            {"$skip": skip},
                            {"$limit": page_size},
                        ],
                        "total": [{"$count": "count"}],
                    }
                },
            ]
        )

        # $facet yields a single result document holding both items and total
        cursor = self.crawls.aggregate(aggregate)
        results = await cursor.to_list(length=1)
        result = results[0]
        items = result["items"]

        try:
            total = int(result["total"][0]["count"])
        except (IndexError, ValueError):
            total = 0

        crawls = []
        for res in items:
            files = None
            if res.get("files"):
                files = [CrawlFile(**data) for data in res["files"]]
                del res["files"]

            crawl = cls_type.from_dict(res)
            if hasattr(crawl, "resources"):
                # pylint: disable=attribute-defined-outside-init
                crawl.resources = await self._resolve_signed_urls(files, org, crawl.id)

            crawls.append(crawl)

        return crawls, total

    async def delete_crawls_all_types(
        self, delete_list: DeleteCrawlList, org: Optional[Organization] = None
    ):
        """Delete crawls of all types (crawls and uploads)"""
        deleted_count, _, _ = await self.delete_crawls(org, delete_list)

        if deleted_count < 1:
            raise HTTPException(status_code=404, detail="crawl_not_found")

        return {"deleted": True}


# ============================================================================
def init_base_crawls_api(app, mdb, users, crawl_manager, orgs, user_dep):
    """base crawls api"""
    # pylint: disable=invalid-name, duplicate-code, too-many-arguments

    ops = BaseCrawlOps(mdb, users, crawl_manager)

    org_viewer_dep = orgs.org_viewer_dep
    org_crawl_dep = orgs.org_crawl_dep

    @app.get(
        "/orgs/{oid}/all-crawls",
        tags=["all-crawls"],
        response_model=PaginatedResponseModel,
    )
    async def list_all_base_crawls(
        org: Organization = Depends(org_viewer_dep),
        pageSize: int = DEFAULT_PAGE_SIZE,
        page: int = 1,
        userid: Optional[UUID4] = None,
        name: Optional[str] = None,
        state: Optional[str] = None,
        description: Optional[str] = None,
        collectionId: Optional[UUID4] = None,
        sortBy: Optional[str] = "finished",
        sortDirection: Optional[int] = -1,
    ):
        # state is a comma-separated list of crawl states to filter on
        states = state.split(",") if state else None
        crawls, total = await ops.list_all_base_crawls(
            org,
            userid=userid,
            name=name,
            description=description,
            collection_id=collectionId,
            states=states,
            page_size=pageSize,
            page=page,
            sort_by=sortBy,
            sort_direction=sortDirection,
        )
        return paginated_format(crawls, total, page, pageSize)

    @app.get(
        "/orgs/{oid}/all-crawls/{crawlid}",
        tags=["all-crawls"],
        response_model=BaseCrawlOutWithResources,
    )
    async def get_base_crawl(crawlid: str, org: Organization = Depends(org_crawl_dep)):
        res = await ops.get_resource_resolved_raw_crawl(crawlid, org)
        return BaseCrawlOutWithResources.from_dict(res)

    @app.get(
        "/orgs/all/all-crawls/{crawl_id}/replay.json",
        tags=["all-crawls"],
        response_model=BaseCrawlOutWithResources,
    )
    async def get_base_crawl_admin(crawl_id, user: User = Depends(user_dep)):
        # superuser-only view that spans all orgs
        if not user.is_superuser:
            raise HTTPException(status_code=403, detail="Not Allowed")

        return await ops.get_crawl(crawl_id, None)

    @app.get(
        "/orgs/{oid}/all-crawls/{crawl_id}/replay.json",
        tags=["all-crawls"],
        response_model=BaseCrawlOutWithResources,
    )
    async def get_crawl(crawl_id, org: Organization = Depends(org_viewer_dep)):
        return await ops.get_crawl(crawl_id, org)

    @app.post("/orgs/{oid}/all-crawls/delete", tags=["all-crawls"])
    async def delete_crawls_all_types(
        delete_list: DeleteCrawlList,
        org: Organization = Depends(org_crawl_dep),
    ):
        return await ops.delete_crawls_all_types(delete_list, org)
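

# Example wiring (illustrative sketch only -- the real setup lives in the
# app entrypoint, and the argument objects named here are assumed to exist
# there, not defined in this module):
#
#   init_base_crawls_api(
#       app, mdb, user_manager, crawl_manager, org_ops, current_active_user
#   )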