Add Org Check for Collection access (#2616)
Ensure collection access checks org membership
commit 86e35e358d
parent e29db33629
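The change applies one pattern throughout: every collection lookup is scoped by both the collection id and the owning org id, so a collection id requested under a different org now resolves to collection_not_found instead of leaking across orgs. Below is a minimal sketch of that pattern, not the actual CollectionOps method: the standalone signature, the name get_collection_raw_scoped, and the generic Motor collection handle passed in as `collections` are all assumptions for illustration.

from typing import Any, Dict
from uuid import UUID

from fastapi import HTTPException


async def get_collection_raw_scoped(
    collections, coll_id: UUID, oid: UUID, public_or_unlisted_only: bool = False
) -> Dict[str, Any]:
    # Query by both _id and oid: a collection that exists but belongs to
    # another org simply does not match, and the caller gets a 404.
    query: dict[str, object] = {"_id": coll_id, "oid": oid}
    if public_or_unlisted_only:
        query["access"] = {"$in": ["public", "unlisted"]}

    result = await collections.find_one(query)
    if not result:
        raise HTTPException(status_code=404, detail="collection_not_found")
    return result

The rest of the diff threads the org id (org.id or oid) through the helpers that previously took only coll_id.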
@@ -150,7 +150,7 @@ class CollectionOps:
         if crawl_ids:
             await self.crawl_ops.add_to_collection(crawl_ids, coll_id, org)
             await self.update_collection_counts_and_tags(coll_id)
-            await self.update_collection_dates(coll_id)
+            await self.update_collection_dates(coll_id, org.id)
             asyncio.create_task(
                 self.event_webhook_ops.create_added_to_collection_notification(
                     crawl_ids, coll_id, org
@@ -179,7 +179,7 @@ class CollectionOps:

         if name_update or slug_update:
             # If we're updating slug, save old one to previousSlugs to support redirects
-            coll = await self.get_collection(coll_id)
+            coll = await self.get_collection(coll_id, org.id)
             previous_slug = coll.slug

         if name_update and not slug_update:
@@ -237,7 +237,7 @@ class CollectionOps:
             raise HTTPException(status_code=404, detail="collection_not_found")

         await self.update_collection_counts_and_tags(coll_id)
-        await self.update_collection_dates(coll_id)
+        await self.update_collection_dates(coll_id, org.id)

         asyncio.create_task(
             self.event_webhook_ops.create_added_to_collection_notification(
@@ -262,7 +262,7 @@ class CollectionOps:
             raise HTTPException(status_code=404, detail="collection_not_found")

         await self.update_collection_counts_and_tags(coll_id)
-        await self.update_collection_dates(coll_id)
+        await self.update_collection_dates(coll_id, org.id)

         asyncio.create_task(
             self.event_webhook_ops.create_removed_from_collection_notification(
@@ -273,10 +273,10 @@ class CollectionOps:
         return await self.get_collection_out(coll_id, org)

     async def get_collection_raw(
-        self, coll_id: UUID, public_or_unlisted_only: bool = False
+        self, coll_id: UUID, oid: UUID, public_or_unlisted_only: bool = False
     ) -> Dict[str, Any]:
         """Get collection by id as dict from database"""
-        query: dict[str, object] = {"_id": coll_id}
+        query: dict[str, object] = {"_id": coll_id, "oid": oid}
         if public_or_unlisted_only:
             query["access"] = {"$in": ["public", "unlisted"]}

@@ -308,10 +308,10 @@ class CollectionOps:
         return result

     async def get_collection(
-        self, coll_id: UUID, public_or_unlisted_only: bool = False
+        self, coll_id: UUID, oid: UUID, public_or_unlisted_only: bool = False
     ) -> Collection:
         """Get collection by id"""
-        result = await self.get_collection_raw(coll_id, public_or_unlisted_only)
+        result = await self.get_collection_raw(coll_id, oid, public_or_unlisted_only)
         return Collection.from_dict(result)

     async def get_collection_by_slug(
@@ -344,7 +344,7 @@ class CollectionOps:
     ) -> CollOut:
         """Get CollOut by id"""
         # pylint: disable=too-many-locals
-        result = await self.get_collection_raw(coll_id, public_or_unlisted_only)
+        result = await self.get_collection_raw(coll_id, org.id, public_or_unlisted_only)

         if resources:
             (
@@ -393,7 +393,7 @@ class CollectionOps:
         allow_unlisted: bool = False,
     ) -> PublicCollOut:
         """Get PublicCollOut by id"""
-        result = await self.get_collection_raw(coll_id)
+        result = await self.get_collection_raw(coll_id, org.id)

         result["orgName"] = org.name
         result["orgPublicProfile"] = org.enablePublicProfile
@@ -569,7 +569,7 @@ class CollectionOps:
         match: dict[str, Any]

         if coll_id:
-            crawl_ids = await self.get_collection_crawl_ids(coll_id)
+            crawl_ids = await self.get_collection_crawl_ids(coll_id, org.id)
             match = {"_id": {"$in": crawl_ids}}
         else:
             crawl_ids = []
@@ -600,13 +600,16 @@ class CollectionOps:
         return {"names": names}

     async def get_collection_crawl_ids(
-        self, coll_id: UUID, public_or_unlisted_only=False
+        self,
+        coll_id: UUID,
+        oid: UUID,
+        public_or_unlisted_only=False,
     ) -> List[str]:
         """Return list of crawl ids in collection, including only public collections"""
         crawl_ids = []
         # ensure collection is public or unlisted, else throw here
         if public_or_unlisted_only:
-            await self.get_collection_raw(coll_id, public_or_unlisted_only)
+            await self.get_collection_raw(coll_id, oid, public_or_unlisted_only)

         async for crawl_raw in self.crawls.find(
             {"collectionIds": coll_id}, projection=["_id"]
@@ -654,7 +657,7 @@ class CollectionOps:
         """recalculate counts, tags and dates for all collections in an org"""
         async for coll in self.collections.find({"oid": org.id}, projection={"_id": 1}):
             await self.update_collection_counts_and_tags(coll.get("_id"))
-            await self.update_collection_dates(coll.get("_id"))
+            await self.update_collection_dates(coll.get("_id"), org.id)

     async def update_collection_counts_and_tags(self, collection_id: UUID):
         """Set current crawl info in config when crawl begins"""
@@ -721,11 +724,11 @@ class CollectionOps:
             },
         )

-    async def update_collection_dates(self, coll_id: UUID):
+    async def update_collection_dates(self, coll_id: UUID, oid: UUID):
         """Update collection earliest and latest dates from page timestamps"""
         # pylint: disable=too-many-locals
-        coll = await self.get_collection(coll_id)
-        crawl_ids = await self.get_collection_crawl_ids(coll_id)
+        coll = await self.get_collection(coll_id, oid)
+        crawl_ids = await self.get_collection_crawl_ids(coll_id, oid)

         earliest_ts = None
         latest_ts = None
@@ -762,7 +765,7 @@ class CollectionOps:
             },
         )

-    async def update_crawl_collections(self, crawl_id: str):
+    async def update_crawl_collections(self, crawl_id: str, oid: UUID):
         """Update counts, dates, and modified for all collections in crawl"""
         crawl = await self.crawls.find_one({"_id": crawl_id})
         crawl_coll_ids = crawl.get("collectionIds")
@@ -770,14 +773,16 @@ class CollectionOps:

         for coll_id in crawl_coll_ids:
             await self.update_collection_counts_and_tags(coll_id)
-            await self.update_collection_dates(coll_id)
+            await self.update_collection_dates(coll_id, oid)
             await self.collections.find_one_and_update(
                 {"_id": coll_id},
                 {"$set": {"modified": modified}},
                 return_document=pymongo.ReturnDocument.AFTER,
             )

-    async def add_successful_crawl_to_collections(self, crawl_id: str, cid: UUID):
+    async def add_successful_crawl_to_collections(
+        self, crawl_id: str, cid: UUID, oid: UUID
+    ):
         """Add successful crawl to its auto-add collections."""
         workflow = await self.crawl_configs.find_one({"_id": cid})
         auto_add_collections = workflow.get("autoAddCollections")
@@ -786,7 +791,7 @@ class CollectionOps:
             {"_id": crawl_id},
             {"$set": {"collectionIds": auto_add_collections}},
         )
-        await self.update_crawl_collections(crawl_id)
+        await self.update_crawl_collections(crawl_id, oid)

     async def get_org_public_collections(
         self,
@@ -862,7 +867,7 @@ class CollectionOps:
         source_page_id: Optional[UUID] = None,
     ) -> Dict[str, bool]:
         """Upload file as stream to use as collection thumbnail"""
-        coll = await self.get_collection(coll_id)
+        coll = await self.get_collection(coll_id, org.id)

         _, extension = os.path.splitext(filename)

@@ -936,7 +941,7 @@ class CollectionOps:

     async def delete_thumbnail(self, coll_id: UUID, org: Organization):
         """Delete collection thumbnail"""
-        coll = await self.get_collection(coll_id)
+        coll = await self.get_collection(coll_id, org.id)

         if not coll.thumbnail:
             raise HTTPException(status_code=404, detail="thumbnail_not_found")
@@ -1565,7 +1565,9 @@ class CrawlOperator(BaseOperator):
                 crawl.oid, status.filesAddedSize, "crawl"
             )
             await self.org_ops.set_last_crawl_finished(crawl.oid)
-            await self.coll_ops.add_successful_crawl_to_collections(crawl.id, crawl.cid)
+            await self.coll_ops.add_successful_crawl_to_collections(
+                crawl.id, crawl.cid, crawl.oid
+            )

         if state in FAILED_STATES:
             await self.crawl_ops.delete_crawl_files(crawl.id, crawl.oid)
@@ -530,7 +530,6 @@ class PageOps:
         coll_id: Optional[UUID] = None,
         crawl_ids: Optional[List[str]] = None,
         public_or_unlisted_only=False,
-        # pylint: disable=unused-argument
         org: Optional[Organization] = None,
         search: Optional[str] = None,
         url: Optional[str] = None,
@@ -568,8 +567,15 @@ class PageOps:
                     detail="only one of crawl_ids or coll_id can be provided",
                 )

+            if not org:
+                raise HTTPException(
+                    status_code=400, detail="org_missing_for_coll_pages"
+                )
+
             crawl_ids = await self.coll_ops.get_collection_crawl_ids(
-                coll_id, public_or_unlisted_only
+                coll_id,
+                org.id,
+                public_or_unlisted_only,
             )

         if not crawl_ids:
@@ -741,12 +747,13 @@ class PageOps:
     async def list_page_url_counts(
         self,
         coll_id: UUID,
+        oid: UUID,
         url_prefix: Optional[str] = None,
         page_size: int = DEFAULT_PAGE_SIZE,
     ) -> List[PageUrlCount]:
         """List all page URLs in collection sorted desc by snapshot count
         unless prefix is specified"""
-        crawl_ids = await self.coll_ops.get_collection_crawl_ids(coll_id)
+        crawl_ids = await self.coll_ops.get_collection_crawl_ids(coll_id, oid)

         pages, _ = await self.list_pages(
             crawl_ids=crawl_ids,
@@ -1475,14 +1482,15 @@ def init_pages_api(
     )
     async def get_collection_url_list(
         coll_id: UUID,
-        # oid: UUID,
         urlPrefix: Optional[str] = None,
         pageSize: int = DEFAULT_PAGE_SIZE,
+        org: Organization = Depends(org_viewer_dep),
         # page: int = 1,
     ):
         """Retrieve paginated list of urls in collection sorted by snapshot count"""
         pages = await ops.list_page_url_counts(
             coll_id=coll_id,
+            oid=org.id,
             url_prefix=urlPrefix,
             page_size=pageSize,
         )
@@ -192,7 +192,7 @@ class UploadOps(BaseCrawlOps):
         )

         asyncio.create_task(
-            self._add_pages_and_update_collections(crawl_id, collections)
+            self._add_pages_and_update_collections(crawl_id, org.id, collections)
         )

         await self.orgs.inc_org_bytes_stored(org.id, file_size, "upload")
@@ -208,11 +208,11 @@ class UploadOps(BaseCrawlOps):
         return {"id": crawl_id, "added": True, "storageQuotaReached": quota_reached}

     async def _add_pages_and_update_collections(
-        self, crawl_id: str, collections: Optional[List[str]] = None
+        self, crawl_id: str, oid: UUID, collections: Optional[List[str]] = None
     ):
         await self.page_ops.add_crawl_pages_to_db_from_wacz(crawl_id)
         if collections:
-            await self.colls.update_crawl_collections(crawl_id)
+            await self.colls.update_crawl_collections(crawl_id, oid)

     async def delete_uploads(
         self,
@@ -497,6 +497,21 @@ def test_collection_public(crawler_auth_headers, default_org_id):
     assert r.headers["Access-Control-Allow-Origin"] == "*"
     assert r.headers["Access-Control-Allow-Headers"] == "*"


+def test_collection_wrong_org(admin_auth_headers, non_default_org_id):
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{non_default_org_id}/collections/{_coll_id}/replay.json",
+        headers=admin_auth_headers,
+    )
+    assert r.status_code == 404
+
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{non_default_org_id}/collections/{_coll_id}/public/replay.json",
+    )
+    assert r.status_code == 404
+
+
 def test_collection_public_make_private(crawler_auth_headers, default_org_id):
     # make private again
     r = requests.patch(
         f"{API_PREFIX}/orgs/{default_org_id}/collections/{_coll_id}",