Add Org Check for Collection access (#2616)

Ensure collection access checks org membership
Ilya Kreymer 2025-05-20 15:30:22 -07:00 committed by GitHub
parent e29db33629
commit 86e35e358d
5 changed files with 61 additions and 31 deletions
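
The core pattern applied throughout this commit: collection lookups are now scoped to the requesting org by filtering on both the collection id and the org id, so a valid collection id presented under the wrong org never matches and the request fails as not found. Below is a minimal sketch of the scoped lookup as it would sit inside CollectionOps, assuming a Motor-style `self.collections` handle and the UUID/Dict/HTTPException imports used elsewhere in this file; the not-found handling is illustrative rather than copied from the diff.

    async def get_collection_raw(
        self, coll_id: UUID, oid: UUID, public_or_unlisted_only: bool = False
    ) -> Dict[str, Any]:
        """Get collection by id as dict, scoped to the requesting org"""
        # "oid" in the query is the org check: a coll_id owned by another org
        # simply won't match, and the caller sees collection_not_found
        query: dict[str, object] = {"_id": coll_id, "oid": oid}
        if public_or_unlisted_only:
            query["access"] = {"$in": ["public", "unlisted"]}

        result = await self.collections.find_one(query)
        if not result:
            raise HTTPException(status_code=404, detail="collection_not_found")
        return result

Callers inside the org (collection updates, thumbnail upload/delete, page listing) now pass org.id or oid through, and the new wrong-org test at the end expects a 404 from both the authenticated and the public replay.json endpoints.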

View File

@@ -150,7 +150,7 @@ class CollectionOps:
if crawl_ids:
await self.crawl_ops.add_to_collection(crawl_ids, coll_id, org)
await self.update_collection_counts_and_tags(coll_id)
- await self.update_collection_dates(coll_id)
+ await self.update_collection_dates(coll_id, org.id)
asyncio.create_task(
self.event_webhook_ops.create_added_to_collection_notification(
crawl_ids, coll_id, org
@@ -179,7 +179,7 @@ class CollectionOps:
if name_update or slug_update:
# If we're updating slug, save old one to previousSlugs to support redirects
- coll = await self.get_collection(coll_id)
+ coll = await self.get_collection(coll_id, org.id)
previous_slug = coll.slug
if name_update and not slug_update:
@@ -237,7 +237,7 @@ class CollectionOps:
raise HTTPException(status_code=404, detail="collection_not_found")
await self.update_collection_counts_and_tags(coll_id)
- await self.update_collection_dates(coll_id)
+ await self.update_collection_dates(coll_id, org.id)
asyncio.create_task(
self.event_webhook_ops.create_added_to_collection_notification(
@@ -262,7 +262,7 @@ class CollectionOps:
raise HTTPException(status_code=404, detail="collection_not_found")
await self.update_collection_counts_and_tags(coll_id)
- await self.update_collection_dates(coll_id)
+ await self.update_collection_dates(coll_id, org.id)
asyncio.create_task(
self.event_webhook_ops.create_removed_from_collection_notification(
@@ -273,10 +273,10 @@ class CollectionOps:
return await self.get_collection_out(coll_id, org)
async def get_collection_raw(
- self, coll_id: UUID, public_or_unlisted_only: bool = False
+ self, coll_id: UUID, oid: UUID, public_or_unlisted_only: bool = False
) -> Dict[str, Any]:
"""Get collection by id as dict from database"""
- query: dict[str, object] = {"_id": coll_id}
+ query: dict[str, object] = {"_id": coll_id, "oid": oid}
if public_or_unlisted_only:
query["access"] = {"$in": ["public", "unlisted"]}
@@ -308,10 +308,10 @@ class CollectionOps:
return result
async def get_collection(
- self, coll_id: UUID, public_or_unlisted_only: bool = False
+ self, coll_id: UUID, oid: UUID, public_or_unlisted_only: bool = False
) -> Collection:
"""Get collection by id"""
- result = await self.get_collection_raw(coll_id, public_or_unlisted_only)
+ result = await self.get_collection_raw(coll_id, oid, public_or_unlisted_only)
return Collection.from_dict(result)
async def get_collection_by_slug(
@@ -344,7 +344,7 @@ class CollectionOps:
) -> CollOut:
"""Get CollOut by id"""
# pylint: disable=too-many-locals
- result = await self.get_collection_raw(coll_id, public_or_unlisted_only)
+ result = await self.get_collection_raw(coll_id, org.id, public_or_unlisted_only)
if resources:
(
@@ -393,7 +393,7 @@ class CollectionOps:
allow_unlisted: bool = False,
) -> PublicCollOut:
"""Get PublicCollOut by id"""
- result = await self.get_collection_raw(coll_id)
+ result = await self.get_collection_raw(coll_id, org.id)
result["orgName"] = org.name
result["orgPublicProfile"] = org.enablePublicProfile
@@ -569,7 +569,7 @@ class CollectionOps:
match: dict[str, Any]
if coll_id:
- crawl_ids = await self.get_collection_crawl_ids(coll_id)
+ crawl_ids = await self.get_collection_crawl_ids(coll_id, org.id)
match = {"_id": {"$in": crawl_ids}}
else:
crawl_ids = []
@@ -600,13 +600,16 @@ class CollectionOps:
return {"names": names}
async def get_collection_crawl_ids(
- self, coll_id: UUID, public_or_unlisted_only=False
+ self,
+ coll_id: UUID,
+ oid: UUID,
+ public_or_unlisted_only=False,
) -> List[str]:
"""Return list of crawl ids in collection, including only public collections"""
crawl_ids = []
# ensure collection is public or unlisted, else throw here
if public_or_unlisted_only:
- await self.get_collection_raw(coll_id, public_or_unlisted_only)
+ await self.get_collection_raw(coll_id, oid, public_or_unlisted_only)
async for crawl_raw in self.crawls.find(
{"collectionIds": coll_id}, projection=["_id"]
@@ -654,7 +657,7 @@ class CollectionOps:
"""recalculate counts, tags and dates for all collections in an org"""
async for coll in self.collections.find({"oid": org.id}, projection={"_id": 1}):
await self.update_collection_counts_and_tags(coll.get("_id"))
- await self.update_collection_dates(coll.get("_id"))
+ await self.update_collection_dates(coll.get("_id"), org.id)
async def update_collection_counts_and_tags(self, collection_id: UUID):
"""Set current crawl info in config when crawl begins"""
@@ -721,11 +724,11 @@ class CollectionOps:
},
)
- async def update_collection_dates(self, coll_id: UUID):
+ async def update_collection_dates(self, coll_id: UUID, oid: UUID):
"""Update collection earliest and latest dates from page timestamps"""
# pylint: disable=too-many-locals
- coll = await self.get_collection(coll_id)
- crawl_ids = await self.get_collection_crawl_ids(coll_id)
+ coll = await self.get_collection(coll_id, oid)
+ crawl_ids = await self.get_collection_crawl_ids(coll_id, oid)
earliest_ts = None
latest_ts = None
@@ -762,7 +765,7 @@ class CollectionOps:
},
)
- async def update_crawl_collections(self, crawl_id: str):
+ async def update_crawl_collections(self, crawl_id: str, oid: UUID):
"""Update counts, dates, and modified for all collections in crawl"""
crawl = await self.crawls.find_one({"_id": crawl_id})
crawl_coll_ids = crawl.get("collectionIds")
@@ -770,14 +773,16 @@ class CollectionOps:
for coll_id in crawl_coll_ids:
await self.update_collection_counts_and_tags(coll_id)
- await self.update_collection_dates(coll_id)
+ await self.update_collection_dates(coll_id, oid)
await self.collections.find_one_and_update(
{"_id": coll_id},
{"$set": {"modified": modified}},
return_document=pymongo.ReturnDocument.AFTER,
)
- async def add_successful_crawl_to_collections(self, crawl_id: str, cid: UUID):
+ async def add_successful_crawl_to_collections(
+ self, crawl_id: str, cid: UUID, oid: UUID
+ ):
"""Add successful crawl to its auto-add collections."""
workflow = await self.crawl_configs.find_one({"_id": cid})
auto_add_collections = workflow.get("autoAddCollections")
@@ -786,7 +791,7 @@ class CollectionOps:
{"_id": crawl_id},
{"$set": {"collectionIds": auto_add_collections}},
)
- await self.update_crawl_collections(crawl_id)
+ await self.update_crawl_collections(crawl_id, oid)
async def get_org_public_collections(
self,
@@ -862,7 +867,7 @@ class CollectionOps:
source_page_id: Optional[UUID] = None,
) -> Dict[str, bool]:
"""Upload file as stream to use as collection thumbnail"""
- coll = await self.get_collection(coll_id)
+ coll = await self.get_collection(coll_id, org.id)
_, extension = os.path.splitext(filename)
@@ -936,7 +941,7 @@ class CollectionOps:
async def delete_thumbnail(self, coll_id: UUID, org: Organization):
"""Delete collection thumbnail"""
- coll = await self.get_collection(coll_id)
+ coll = await self.get_collection(coll_id, org.id)
if not coll.thumbnail:
raise HTTPException(status_code=404, detail="thumbnail_not_found")

View File

@@ -1565,7 +1565,9 @@ class CrawlOperator(BaseOperator):
crawl.oid, status.filesAddedSize, "crawl"
)
await self.org_ops.set_last_crawl_finished(crawl.oid)
- await self.coll_ops.add_successful_crawl_to_collections(crawl.id, crawl.cid)
+ await self.coll_ops.add_successful_crawl_to_collections(
+ crawl.id, crawl.cid, crawl.oid
+ )
if state in FAILED_STATES:
await self.crawl_ops.delete_crawl_files(crawl.id, crawl.oid)

View File

@@ -530,7 +530,6 @@ class PageOps:
coll_id: Optional[UUID] = None,
crawl_ids: Optional[List[str]] = None,
public_or_unlisted_only=False,
- # pylint: disable=unused-argument
org: Optional[Organization] = None,
search: Optional[str] = None,
url: Optional[str] = None,
@@ -568,8 +567,15 @@ class PageOps:
detail="only one of crawl_ids or coll_id can be provided",
)
+ if not org:
+ raise HTTPException(
+ status_code=400, detail="org_missing_for_coll_pages"
+ )
crawl_ids = await self.coll_ops.get_collection_crawl_ids(
- coll_id, public_or_unlisted_only
+ coll_id,
+ org.id,
+ public_or_unlisted_only,
)
if not crawl_ids:
@@ -741,12 +747,13 @@ class PageOps:
async def list_page_url_counts(
self,
coll_id: UUID,
+ oid: UUID,
url_prefix: Optional[str] = None,
page_size: int = DEFAULT_PAGE_SIZE,
) -> List[PageUrlCount]:
"""List all page URLs in collection sorted desc by snapshot count
unless prefix is specified"""
- crawl_ids = await self.coll_ops.get_collection_crawl_ids(coll_id)
+ crawl_ids = await self.coll_ops.get_collection_crawl_ids(coll_id, oid)
pages, _ = await self.list_pages(
crawl_ids=crawl_ids,
@@ -1475,14 +1482,15 @@ def init_pages_api(
)
async def get_collection_url_list(
coll_id: UUID,
# oid: UUID,
urlPrefix: Optional[str] = None,
pageSize: int = DEFAULT_PAGE_SIZE,
org: Organization = Depends(org_viewer_dep),
# page: int = 1,
):
"""Retrieve paginated list of urls in collection sorted by snapshot count"""
pages = await ops.list_page_url_counts(
coll_id=coll_id,
+ oid=org.id,
url_prefix=urlPrefix,
page_size=pageSize,
)

View File

@@ -192,7 +192,7 @@ class UploadOps(BaseCrawlOps):
)
asyncio.create_task(
- self._add_pages_and_update_collections(crawl_id, collections)
+ self._add_pages_and_update_collections(crawl_id, org.id, collections)
)
await self.orgs.inc_org_bytes_stored(org.id, file_size, "upload")
@@ -208,11 +208,11 @@ class UploadOps(BaseCrawlOps):
return {"id": crawl_id, "added": True, "storageQuotaReached": quota_reached}
async def _add_pages_and_update_collections(
- self, crawl_id: str, collections: Optional[List[str]] = None
+ self, crawl_id: str, oid: UUID, collections: Optional[List[str]] = None
):
await self.page_ops.add_crawl_pages_to_db_from_wacz(crawl_id)
if collections:
- await self.colls.update_crawl_collections(crawl_id)
+ await self.colls.update_crawl_collections(crawl_id, oid)
async def delete_uploads(
self,

View File

@@ -497,6 +497,21 @@ def test_collection_public(crawler_auth_headers, default_org_id):
assert r.headers["Access-Control-Allow-Origin"] == "*"
assert r.headers["Access-Control-Allow-Headers"] == "*"
+ def test_collection_wrong_org(admin_auth_headers, non_default_org_id):
+ r = requests.get(
+ f"{API_PREFIX}/orgs/{non_default_org_id}/collections/{_coll_id}/replay.json",
+ headers=admin_auth_headers,
+ )
+ assert r.status_code == 404
+ r = requests.get(
+ f"{API_PREFIX}/orgs/{non_default_org_id}/collections/{_coll_id}/public/replay.json",
+ )
+ assert r.status_code == 404
def test_collection_public_make_private(crawler_auth_headers, default_org_id):
# make private again
r = requests.patch(
f"{API_PREFIX}/orgs/{default_org_id}/collections/{_coll_id}",