diff --git a/backend/btrixcloud/colls.py b/backend/btrixcloud/colls.py
index 6723e356..633d5260 100644
--- a/backend/btrixcloud/colls.py
+++ b/backend/btrixcloud/colls.py
@@ -49,6 +49,8 @@ class UpdateColl(BaseModel):
 class CollectionOps:
     """ops for working with named collections of crawls"""
 
+    # pylint: disable=too-many-arguments
+
     def __init__(self, mdb, crawls, crawl_manager, orgs):
         self.collections = mdb["collections"]
 
@@ -62,6 +64,11 @@ class CollectionOps:
             [("oid", pymongo.ASCENDING), ("name", pymongo.ASCENDING)], unique=True
         )
 
+        await self.collections.create_index(
+            [("oid", pymongo.ASCENDING), ("description", pymongo.ASCENDING)],
+            unique=True,
+        )
+
     async def add_collection(
         self,
         oid: uuid.UUID,
@@ -156,20 +163,62 @@ class CollectionOps:
         return [result["name"] for result in results]
 
     async def list_collections(
-        self, oid: uuid.UUID, page_size: int = DEFAULT_PAGE_SIZE, page: int = 1
+        self,
+        oid: uuid.UUID,
+        page_size: int = DEFAULT_PAGE_SIZE,
+        page: int = 1,
+        sort_by: str = None,
+        sort_direction: int = 1,
+        name: Optional[str] = None,
     ):
         """List all collections for org"""
+        # pylint: disable=too-many-locals
         # Zero-index page for query
         page = page - 1
         skip = page * page_size
 
         match_query = {"oid": oid}
 
-        total = await self.collections.count_documents(match_query)
+        if name:
+            match_query["name"] = name
 
-        cursor = self.collections.find(match_query, skip=skip, limit=page_size)
-        results = await cursor.to_list(length=page_size)
-        collections = [Collection.from_dict(res) for res in results]
+        aggregate = [{"$match": match_query}]
+
+        if sort_by:
+            if sort_by not in ("name", "description"):
+                raise HTTPException(status_code=400, detail="invalid_sort_by")
+            if sort_direction not in (1, -1):
+                raise HTTPException(status_code=400, detail="invalid_sort_direction")
+
+            aggregate.extend([{"$sort": {sort_by: sort_direction}}])
+
+        aggregate.extend(
+            [
+                {
+                    "$facet": {
+                        "items": [
+                            {"$skip": skip},
+                            {"$limit": page_size},
+                        ],
+                        "total": [{"$count": "count"}],
+                    }
+                },
+            ]
+        )
+
+        cursor = self.collections.aggregate(
+            aggregate, collation=pymongo.collation.Collation(locale="en")
+        )
+        results = await cursor.to_list(length=1)
+        result = results[0]
+        items = result["items"]
+
+        try:
+            total = int(result["total"][0]["count"])
+        except (IndexError, ValueError):
+            total = 0
+
+        collections = [Collection.from_dict(res) for res in items]
 
         return collections, total
 
@@ -193,12 +242,16 @@ class CollectionOps:
 
         return {"resources": all_files}
 
+    async def get_collection_names(self, org: Organization):
+        """Return list of collection names"""
+        return await self.collections.distinct("name", {"oid": org.id})
+
 
 # ============================================================================
 # pylint: disable=too-many-locals
 def init_collections_api(app, mdb, crawls, orgs, crawl_manager):
     """init collections api"""
-    # pylint: disable=invalid-name, unused-argument
+    # pylint: disable=invalid-name, unused-argument, too-many-arguments
 
     colls = CollectionOps(mdb, crawls, crawl_manager, orgs)
 
@@ -221,9 +274,17 @@ def init_collections_api(app, mdb, crawls, orgs, crawl_manager):
         org: Organization = Depends(org_viewer_dep),
         pageSize: int = DEFAULT_PAGE_SIZE,
         page: int = 1,
+        sortBy: str = None,
+        sortDirection: int = 1,
+        name: Optional[str] = None,
     ):
         collections, total = await colls.list_collections(
-            org.id, page_size=pageSize, page=page
+            org.id,
+            page_size=pageSize,
+            page=page,
+            sort_by=sortBy,
+            sort_direction=sortDirection,
+            name=name,
         )
         return paginated_format(collections, total, page, pageSize)
 
@@ -247,10 +308,13 @@ def init_collections_api(app, mdb, crawls, orgs, crawl_manager):
 
         return results
 
-    @app.get(
-        "/orgs/{oid}/collections/{coll_id}",
-        tags=["collections"],
-    )
+    @app.get("/orgs/{oid}/collections/names", tags=["collections"])
+    async def get_collection_names(
+        org: Organization = Depends(org_viewer_dep),
+    ):
+        return await colls.get_collection_names(org)
+
+    @app.get("/orgs/{oid}/collections/{coll_id}", tags=["collections"])
     async def get_collection_crawls(
         coll_id: uuid.UUID, org: Organization = Depends(org_viewer_dep)
     ):
diff --git a/backend/btrixcloud/crawlconfigs.py b/backend/btrixcloud/crawlconfigs.py
index 75c23924..d2b81187 100644
--- a/backend/btrixcloud/crawlconfigs.py
+++ b/backend/btrixcloud/crawlconfigs.py
@@ -525,7 +525,7 @@ class CrawlConfigOps:
             aggregate.extend([{"$match": {"firstSeed": first_seed}}])
 
         if sort_by:
-            if sort_by not in ("created, modified, firstSeed, lastCrawlTime"):
+            if sort_by not in ("created", "modified", "firstSeed", "lastCrawlTime"):
                 raise HTTPException(status_code=400, detail="invalid_sort_by")
             if sort_direction not in (1, -1):
                 raise HTTPException(status_code=400, detail="invalid_sort_direction")
diff --git a/backend/test/test_collections.py b/backend/test/test_collections.py
index afdec163..69ef3619 100644
--- a/backend/test/test_collections.py
+++ b/backend/test/test_collections.py
@@ -187,3 +187,82 @@ def test_list_collections(
     assert second_coll["oid"] == default_org_id
     assert second_coll.get("description") is None
     assert second_coll["crawlIds"] == [crawler_crawl_id]
+
+
+def test_filter_sort_collections(
+    crawler_auth_headers, default_org_id, crawler_crawl_id, admin_crawl_id
+):
+    # Test filtering by name
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/collections?name={SECOND_COLLECTION_NAME}",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert data["total"] == 1
+
+    items = data["items"]
+    assert len(items) == 1
+
+    coll = items[0]
+    assert coll["id"]
+    assert coll["name"] == SECOND_COLLECTION_NAME
+    assert coll["oid"] == default_org_id
+    assert coll.get("description") is None
+    assert coll["crawlIds"] == [crawler_crawl_id]
+
+    # Test sorting by name, ascending (default)
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/collections?sortBy=name",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert data["total"] == 2
+
+    items = data["items"]
+    assert items[0]["name"] == SECOND_COLLECTION_NAME
+    assert items[1]["name"] == UPDATED_NAME
+
+    # Test sorting by name, descending
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/collections?sortBy=name&sortDirection=-1",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert data["total"] == 2
+
+    items = data["items"]
+    assert items[0]["name"] == UPDATED_NAME
+    assert items[1]["name"] == SECOND_COLLECTION_NAME
+
+    # Test sorting by description, ascending (default)
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/collections?sortBy=description",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert data["total"] == 2
+
+    items = data["items"]
+    assert items[0]["name"] == SECOND_COLLECTION_NAME
+    assert items[0].get("description") is None
+    assert items[1]["name"] == UPDATED_NAME
+    assert items[1]["description"] == DESCRIPTION
+
+    # Test sorting by description, descending
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/collections?sortBy=description&sortDirection=-1",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert data["total"] == 2
+
+    items = data["items"]
+    assert items[0]["name"] == UPDATED_NAME
+    assert items[0]["description"] == DESCRIPTION
+    assert items[1]["name"] == SECOND_COLLECTION_NAME
+    assert items[1].get("description") is None
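
Note: the core of the colls.py change above is a single MongoDB aggregation that returns one page of results and the total match count in one round trip via $facet. The standalone sketch below isolates that pattern; it is illustrative only, assumes a motor async collection and the same skip/page_size/sort inputs as list_collections, and the helper name run_paged_query is not part of the diff.

from pymongo.collation import Collation


async def run_paged_query(
    collections, match_query, skip, page_size, sort_by=None, sort_direction=1
):
    """Sketch of the $facet paging pattern used by list_collections."""
    aggregate = [{"$match": match_query}]

    if sort_by:
        # Sort before paging so $skip/$limit see a stable order
        aggregate.append({"$sort": {sort_by: sort_direction}})

    aggregate.append(
        {
            "$facet": {
                # "items": just the requested page
                "items": [{"$skip": skip}, {"$limit": page_size}],
                # "total": count of everything that matched, ignoring paging
                "total": [{"$count": "count"}],
            }
        }
    )

    # Locale-aware collation so sorting is alphabetical rather than raw byte order
    cursor = collections.aggregate(aggregate, collation=Collation(locale="en"))
    result = (await cursor.to_list(length=1))[0]

    try:
        total = int(result["total"][0]["count"])
    except (IndexError, ValueError):
        # The "total" facet is an empty list when nothing matched
        total = 0

    return result["items"], total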
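For reference, a hypothetical client-side call exercising the new query parameters and the new names endpoint could look like the following; the URL, org id, and token are placeholders rather than values from this PR (the tests above get them from conftest fixtures).

import requests

# Placeholder values for illustration only
API_PREFIX = "http://localhost:8000/api"
ORG_ID = "<org-uuid>"
HEADERS = {"Authorization": "Bearer <access-token>"}

# Filter by exact name, sort by name descending, first page of 10
r = requests.get(
    f"{API_PREFIX}/orgs/{ORG_ID}/collections",
    params={
        "name": "My Collection",
        "sortBy": "name",
        "sortDirection": -1,
        "pageSize": 10,
        "page": 1,
    },
    headers=HEADERS,
)
data = r.json()
print(data["total"], [coll["name"] for coll in data["items"]])

# New endpoint: distinct collection names for the org
r = requests.get(f"{API_PREFIX}/orgs/{ORG_ID}/collections/names", headers=HEADERS)
print(r.json())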