Add Org Check for Collection access (#2616)

Ensure collection access checks org membership
Ilya Kreymer 2025-05-20 15:30:22 -07:00 committed by GitHub
parent e29db33629
commit 86e35e358d
5 changed files with 61 additions and 31 deletions
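
The heart of the change is a single pattern applied across the backend: every collection lookup now filters on the owning org's id (`oid`) as well as the collection `_id`, so a valid collection id requested through another org's API routes behaves as if the collection does not exist. A minimal sketch of that pattern, assuming a Motor-style async MongoDB collection handle and FastAPI's HTTPException; the names mirror the diff below, but this is an illustration, not the full implementation:

from typing import Any, Dict
from uuid import UUID

from fastapi import HTTPException


async def get_collection_raw(
    collections, coll_id: UUID, oid: UUID, public_or_unlisted_only: bool = False
) -> Dict[str, Any]:
    # Filter on both the collection _id and the owning org's oid: a coll_id
    # that exists under a different org simply never matches the query.
    query: dict[str, object] = {"_id": coll_id, "oid": oid}
    if public_or_unlisted_only:
        query["access"] = {"$in": ["public", "unlisted"]}

    result = await collections.find_one(query)
    if not result:
        # The same 404 is returned whether the collection is missing or
        # org-mismatched, so the response never confirms the id exists.
        raise HTTPException(status_code=404, detail="collection_not_found")
    return result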


@@ -150,7 +150,7 @@ class CollectionOps:
         if crawl_ids:
             await self.crawl_ops.add_to_collection(crawl_ids, coll_id, org)
             await self.update_collection_counts_and_tags(coll_id)
-            await self.update_collection_dates(coll_id)
+            await self.update_collection_dates(coll_id, org.id)
             asyncio.create_task(
                 self.event_webhook_ops.create_added_to_collection_notification(
                     crawl_ids, coll_id, org
@@ -179,7 +179,7 @@ class CollectionOps:
         if name_update or slug_update:
             # If we're updating slug, save old one to previousSlugs to support redirects
-            coll = await self.get_collection(coll_id)
+            coll = await self.get_collection(coll_id, org.id)
             previous_slug = coll.slug

         if name_update and not slug_update:
@@ -237,7 +237,7 @@ class CollectionOps:
             raise HTTPException(status_code=404, detail="collection_not_found")

         await self.update_collection_counts_and_tags(coll_id)
-        await self.update_collection_dates(coll_id)
+        await self.update_collection_dates(coll_id, org.id)

         asyncio.create_task(
             self.event_webhook_ops.create_added_to_collection_notification(
@@ -262,7 +262,7 @@ class CollectionOps:
             raise HTTPException(status_code=404, detail="collection_not_found")

         await self.update_collection_counts_and_tags(coll_id)
-        await self.update_collection_dates(coll_id)
+        await self.update_collection_dates(coll_id, org.id)

         asyncio.create_task(
             self.event_webhook_ops.create_removed_from_collection_notification(
@@ -273,10 +273,10 @@ class CollectionOps:
         return await self.get_collection_out(coll_id, org)

     async def get_collection_raw(
-        self, coll_id: UUID, public_or_unlisted_only: bool = False
+        self, coll_id: UUID, oid: UUID, public_or_unlisted_only: bool = False
     ) -> Dict[str, Any]:
         """Get collection by id as dict from database"""
-        query: dict[str, object] = {"_id": coll_id}
+        query: dict[str, object] = {"_id": coll_id, "oid": oid}

         if public_or_unlisted_only:
             query["access"] = {"$in": ["public", "unlisted"]}
@@ -308,10 +308,10 @@ class CollectionOps:
         return result

     async def get_collection(
-        self, coll_id: UUID, public_or_unlisted_only: bool = False
+        self, coll_id: UUID, oid: UUID, public_or_unlisted_only: bool = False
     ) -> Collection:
         """Get collection by id"""
-        result = await self.get_collection_raw(coll_id, public_or_unlisted_only)
+        result = await self.get_collection_raw(coll_id, oid, public_or_unlisted_only)
         return Collection.from_dict(result)

     async def get_collection_by_slug(
@@ -344,7 +344,7 @@ class CollectionOps:
     ) -> CollOut:
         """Get CollOut by id"""
         # pylint: disable=too-many-locals
-        result = await self.get_collection_raw(coll_id, public_or_unlisted_only)
+        result = await self.get_collection_raw(coll_id, org.id, public_or_unlisted_only)

         if resources:
             (
@@ -393,7 +393,7 @@ class CollectionOps:
         allow_unlisted: bool = False,
     ) -> PublicCollOut:
         """Get PublicCollOut by id"""
-        result = await self.get_collection_raw(coll_id)
+        result = await self.get_collection_raw(coll_id, org.id)

         result["orgName"] = org.name
         result["orgPublicProfile"] = org.enablePublicProfile
@@ -569,7 +569,7 @@ class CollectionOps:
         match: dict[str, Any]

         if coll_id:
-            crawl_ids = await self.get_collection_crawl_ids(coll_id)
+            crawl_ids = await self.get_collection_crawl_ids(coll_id, org.id)
             match = {"_id": {"$in": crawl_ids}}
         else:
             crawl_ids = []
@@ -600,13 +600,16 @@ class CollectionOps:
         return {"names": names}

     async def get_collection_crawl_ids(
-        self, coll_id: UUID, public_or_unlisted_only=False
+        self,
+        coll_id: UUID,
+        oid: UUID,
+        public_or_unlisted_only=False,
     ) -> List[str]:
         """Return list of crawl ids in collection, including only public collections"""
         crawl_ids = []
         # ensure collection is public or unlisted, else throw here
         if public_or_unlisted_only:
-            await self.get_collection_raw(coll_id, public_or_unlisted_only)
+            await self.get_collection_raw(coll_id, oid, public_or_unlisted_only)

         async for crawl_raw in self.crawls.find(
             {"collectionIds": coll_id}, projection=["_id"]
@@ -654,7 +657,7 @@ class CollectionOps:
         """recalculate counts, tags and dates for all collections in an org"""
         async for coll in self.collections.find({"oid": org.id}, projection={"_id": 1}):
             await self.update_collection_counts_and_tags(coll.get("_id"))
-            await self.update_collection_dates(coll.get("_id"))
+            await self.update_collection_dates(coll.get("_id"), org.id)

     async def update_collection_counts_and_tags(self, collection_id: UUID):
         """Set current crawl info in config when crawl begins"""
@@ -721,11 +724,11 @@ class CollectionOps:
             },
         )

-    async def update_collection_dates(self, coll_id: UUID):
+    async def update_collection_dates(self, coll_id: UUID, oid: UUID):
         """Update collection earliest and latest dates from page timestamps"""
         # pylint: disable=too-many-locals
-        coll = await self.get_collection(coll_id)
-        crawl_ids = await self.get_collection_crawl_ids(coll_id)
+        coll = await self.get_collection(coll_id, oid)
+        crawl_ids = await self.get_collection_crawl_ids(coll_id, oid)

         earliest_ts = None
         latest_ts = None
@@ -762,7 +765,7 @@ class CollectionOps:
             },
         )

-    async def update_crawl_collections(self, crawl_id: str):
+    async def update_crawl_collections(self, crawl_id: str, oid: UUID):
         """Update counts, dates, and modified for all collections in crawl"""
         crawl = await self.crawls.find_one({"_id": crawl_id})
         crawl_coll_ids = crawl.get("collectionIds")
@@ -770,14 +773,16 @@ class CollectionOps:
         for coll_id in crawl_coll_ids:
             await self.update_collection_counts_and_tags(coll_id)
-            await self.update_collection_dates(coll_id)
+            await self.update_collection_dates(coll_id, oid)
             await self.collections.find_one_and_update(
                 {"_id": coll_id},
                 {"$set": {"modified": modified}},
                 return_document=pymongo.ReturnDocument.AFTER,
             )

-    async def add_successful_crawl_to_collections(self, crawl_id: str, cid: UUID):
+    async def add_successful_crawl_to_collections(
+        self, crawl_id: str, cid: UUID, oid: UUID
+    ):
         """Add successful crawl to its auto-add collections."""
         workflow = await self.crawl_configs.find_one({"_id": cid})
         auto_add_collections = workflow.get("autoAddCollections")
@@ -786,7 +791,7 @@ class CollectionOps:
                 {"_id": crawl_id},
                 {"$set": {"collectionIds": auto_add_collections}},
             )
-            await self.update_crawl_collections(crawl_id)
+            await self.update_crawl_collections(crawl_id, oid)

     async def get_org_public_collections(
         self,
@@ -862,7 +867,7 @@ class CollectionOps:
         source_page_id: Optional[UUID] = None,
     ) -> Dict[str, bool]:
         """Upload file as stream to use as collection thumbnail"""
-        coll = await self.get_collection(coll_id)
+        coll = await self.get_collection(coll_id, org.id)

         _, extension = os.path.splitext(filename)
@@ -936,7 +941,7 @@ class CollectionOps:
     async def delete_thumbnail(self, coll_id: UUID, org: Organization):
         """Delete collection thumbnail"""
-        coll = await self.get_collection(coll_id)
+        coll = await self.get_collection(coll_id, org.id)
         if not coll.thumbnail:
             raise HTTPException(status_code=404, detail="thumbnail_not_found")
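
One consequence of the new signatures runs through the rest of the diff: callers must now supply the org id explicitly, and they get it from different places depending on context. A small runnable sketch of the two call-site flavors; all names here are illustrative stand-ins, not the real classes:

import asyncio
from dataclasses import dataclass
from uuid import UUID, uuid4


@dataclass
class CrawlStub:
    """Stand-in for a crawl record, which persists its owning org id."""
    id: str
    cid: UUID
    oid: UUID


async def update_collection_dates(coll_id: UUID, oid: UUID) -> None:
    print(f"updating dates for {coll_id} scoped to org {oid}")


async def main() -> None:
    coll_id = uuid4()

    # 1) Request handler: the org comes from the route's org dependency,
    #    so handlers pass org.id (shown as a bare UUID here).
    org_id = uuid4()
    await update_collection_dates(coll_id, org_id)

    # 2) Operator/background code: there is no request context, so the
    #    oid stored on the crawl record itself is used (crawl.oid).
    crawl = CrawlStub(id="crawl-1", cid=uuid4(), oid=org_id)
    await update_collection_dates(coll_id, crawl.oid)


asyncio.run(main())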


@@ -1565,7 +1565,9 @@ class CrawlOperator(BaseOperator):
                 crawl.oid, status.filesAddedSize, "crawl"
             )
             await self.org_ops.set_last_crawl_finished(crawl.oid)
-            await self.coll_ops.add_successful_crawl_to_collections(crawl.id, crawl.cid)
+            await self.coll_ops.add_successful_crawl_to_collections(
+                crawl.id, crawl.cid, crawl.oid
+            )

         if state in FAILED_STATES:
             await self.crawl_ops.delete_crawl_files(crawl.id, crawl.oid)


@@ -530,7 +530,6 @@ class PageOps:
         coll_id: Optional[UUID] = None,
         crawl_ids: Optional[List[str]] = None,
         public_or_unlisted_only=False,
-        # pylint: disable=unused-argument
         org: Optional[Organization] = None,
         search: Optional[str] = None,
         url: Optional[str] = None,
@@ -568,8 +567,15 @@ class PageOps:
                 detail="only one of crawl_ids or coll_id can be provided",
             )

+            if not org:
+                raise HTTPException(
+                    status_code=400, detail="org_missing_for_coll_pages"
+                )
+
             crawl_ids = await self.coll_ops.get_collection_crawl_ids(
-                coll_id, public_or_unlisted_only
+                coll_id,
+                org.id,
+                public_or_unlisted_only,
             )

         if not crawl_ids:
@@ -741,12 +747,13 @@ class PageOps:
     async def list_page_url_counts(
         self,
         coll_id: UUID,
+        oid: UUID,
         url_prefix: Optional[str] = None,
         page_size: int = DEFAULT_PAGE_SIZE,
     ) -> List[PageUrlCount]:
         """List all page URLs in collection sorted desc by snapshot count
         unless prefix is specified"""
-        crawl_ids = await self.coll_ops.get_collection_crawl_ids(coll_id)
+        crawl_ids = await self.coll_ops.get_collection_crawl_ids(coll_id, oid)

         pages, _ = await self.list_pages(
             crawl_ids=crawl_ids,
@@ -1475,14 +1482,15 @@ def init_pages_api(
     )
     async def get_collection_url_list(
         coll_id: UUID,
-        # oid: UUID,
         urlPrefix: Optional[str] = None,
         pageSize: int = DEFAULT_PAGE_SIZE,
+        org: Organization = Depends(org_viewer_dep),
         # page: int = 1,
     ):
         """Retrieve paginated list of urls in collection sorted by snapshot count"""
         pages = await ops.list_page_url_counts(
             coll_id=coll_id,
+            oid=org.id,
             url_prefix=urlPrefix,
             page_size=pageSize,
         )
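
On the routing side, get_collection_url_list gains `org: Organization = Depends(org_viewer_dep)`, so org membership is checked before the handler body runs and the handler has an org.id to scope queries with. A self-contained sketch of how such a dependency gates a route; the real org_viewer_dep is defined elsewhere in the codebase, and every internal below is an assumed, simplified stand-in:

from uuid import UUID

from fastapi import Depends, FastAPI, HTTPException

app = FastAPI()

# toy "membership table": org id -> member user ids (stand-in for the DB)
ORG_MEMBERS: dict[UUID, set[str]] = {}


async def get_current_user() -> str:
    """Stand-in auth dependency; the real one authenticates the request."""
    return "user-1"


async def org_viewer_dep(oid: UUID, user: str = Depends(get_current_user)) -> UUID:
    # `oid` is resolved from the /orgs/{oid}/... path segment; FastAPI runs
    # this before the route handler, so non-members never reach the handler.
    members = ORG_MEMBERS.get(oid)
    if not members or user not in members:
        raise HTTPException(status_code=404, detail="org_not_found")
    return oid


@app.get("/orgs/{oid}/collections/{coll_id}/urls")
async def get_collection_url_list(coll_id: UUID, oid: UUID = Depends(org_viewer_dep)):
    # Reached only by verified org members; downstream queries scope by oid.
    return {"coll_id": str(coll_id), "oid": str(oid)}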


@@ -192,7 +192,7 @@ class UploadOps(BaseCrawlOps):
             )

         asyncio.create_task(
-            self._add_pages_and_update_collections(crawl_id, collections)
+            self._add_pages_and_update_collections(crawl_id, org.id, collections)
         )

         await self.orgs.inc_org_bytes_stored(org.id, file_size, "upload")
@@ -208,11 +208,11 @@ class UploadOps(BaseCrawlOps):
         return {"id": crawl_id, "added": True, "storageQuotaReached": quota_reached}

     async def _add_pages_and_update_collections(
-        self, crawl_id: str, collections: Optional[List[str]] = None
+        self, crawl_id: str, oid: UUID, collections: Optional[List[str]] = None
     ):
         await self.page_ops.add_crawl_pages_to_db_from_wacz(crawl_id)
         if collections:
-            await self.colls.update_crawl_collections(crawl_id)
+            await self.colls.update_crawl_collections(crawl_id, oid)

     async def delete_uploads(
         self,
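
In the upload path the collection update runs as a fire-and-forget asyncio task that may outlive the HTTP request, so the org id has to be captured as a plain value at call time rather than recovered from request context later. A small runnable sketch of that pattern, with stand-in names rather than the real UploadOps:

import asyncio
from typing import List, Optional
from uuid import UUID, uuid4


async def _add_pages_and_update_collections(
    crawl_id: str, oid: UUID, collections: Optional[List[str]] = None
) -> None:
    # stand-in body; the real method imports pages from the WACZ, then
    # calls update_crawl_collections(crawl_id, oid) if collections exist
    if collections:
        print(f"updating {collections} for crawl {crawl_id} in org {oid}")


async def handle_upload() -> None:
    crawl_id, org_id = "upload-1", uuid4()
    # org_id is bound into the coroutine now; the task runs after this
    # handler would normally have returned its response
    task = asyncio.create_task(
        _add_pages_and_update_collections(crawl_id, org_id, ["my-collection"])
    )
    await task  # awaited here only so this demo runs to completion


asyncio.run(handle_upload())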


@@ -497,6 +497,21 @@ def test_collection_public(crawler_auth_headers, default_org_id):
     assert r.headers["Access-Control-Allow-Origin"] == "*"
     assert r.headers["Access-Control-Allow-Headers"] == "*"


+def test_collection_wrong_org(admin_auth_headers, non_default_org_id):
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{non_default_org_id}/collections/{_coll_id}/replay.json",
+        headers=admin_auth_headers,
+    )
+    assert r.status_code == 404
+
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{non_default_org_id}/collections/{_coll_id}/public/replay.json",
+    )
+    assert r.status_code == 404
+
+
 def test_collection_public_make_private(crawler_auth_headers, default_org_id):
     # make private again
     r = requests.patch(
         f"{API_PREFIX}/orgs/{default_org_id}/collections/{_coll_id}",