Add collection sorting and filtering (#863)

* Sort by name and description (ascending by default)
* Filter by name
* Add endpoint to fetch collection names for search
* Add collation so that UTF-8 characters sort as expected
This commit is contained in:
Tessa Walsh 2023-05-22 16:53:49 -04:00 committed by GitHub
parent 821fbc12d8
commit 60fac2b677
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 155 additions and 12 deletions

View File

@ -49,6 +49,8 @@ class UpdateColl(BaseModel):
class CollectionOps:
"""ops for working with named collections of crawls"""
# pylint: disable=too-many-arguments
def __init__(self, mdb, crawls, crawl_manager, orgs):
self.collections = mdb["collections"]
@ -62,6 +64,11 @@ class CollectionOps:
[("oid", pymongo.ASCENDING), ("name", pymongo.ASCENDING)], unique=True
)
await self.collections.create_index(
[("oid", pymongo.ASCENDING), ("description", pymongo.ASCENDING)],
unique=True,
)
async def add_collection(
self,
oid: uuid.UUID,
@ -156,20 +163,62 @@ class CollectionOps:
return [result["name"] for result in results]
async def list_collections(
    self,
    oid: uuid.UUID,
    page_size: int = DEFAULT_PAGE_SIZE,
    page: int = 1,
    sort_by: Optional[str] = None,
    sort_direction: int = 1,
    name: Optional[str] = None,
):
    """List all collections for org.

    :param oid: org id to scope the query to
    :param page_size: number of results per page
    :param page: 1-indexed page number
    :param sort_by: optional sort field; only "name" and "description" allowed
    :param sort_direction: 1 for ascending, -1 for descending
    :param name: optional exact-match filter on collection name
    :returns: tuple of (list of Collection, total matching count)
    :raises HTTPException: 400 on invalid sort field or direction
    """
    # pylint: disable=too-many-locals
    # Zero-index page for query
    page = page - 1
    skip = page * page_size

    match_query = {"oid": oid}
    if name:
        match_query["name"] = name

    aggregate = [{"$match": match_query}]

    if sort_by:
        # Validate before building the $sort stage so bad input fails fast
        if sort_by not in ("name", "description"):
            raise HTTPException(status_code=400, detail="invalid_sort_by")
        if sort_direction not in (1, -1):
            raise HTTPException(status_code=400, detail="invalid_sort_direction")

        aggregate.extend([{"$sort": {sort_by: sort_direction}}])

    # $facet lets one round-trip return both the page of items and the
    # total count of all matches (before $skip/$limit are applied)
    aggregate.extend(
        [
            {
                "$facet": {
                    "items": [
                        {"$skip": skip},
                        {"$limit": page_size},
                    ],
                    "total": [{"$count": "count"}],
                }
            },
        ]
    )

    # Collation so that non-ASCII (UTF-8) characters sort as expected
    cursor = self.collections.aggregate(
        aggregate, collation=pymongo.collation.Collation(locale="en")
    )
    results = await cursor.to_list(length=1)
    result = results[0]
    items = result["items"]

    try:
        total = int(result["total"][0]["count"])
    except (IndexError, ValueError):
        # No matches: $count produces no document, so default to 0
        total = 0

    collections = [Collection.from_dict(res) for res in items]

    return collections, total
@ -193,12 +242,16 @@ class CollectionOps:
return {"resources": all_files}
async def get_collection_names(self, org: Organization):
    """Return list of collection names"""
    query = {"oid": org.id}
    names = await self.collections.distinct("name", query)
    return names
# ============================================================================
# pylint: disable=too-many-locals
def init_collections_api(app, mdb, crawls, orgs, crawl_manager):
"""init collections api"""
# pylint: disable=invalid-name, unused-argument
# pylint: disable=invalid-name, unused-argument, too-many-arguments
colls = CollectionOps(mdb, crawls, crawl_manager, orgs)
@ -221,9 +274,17 @@ def init_collections_api(app, mdb, crawls, orgs, crawl_manager):
org: Organization = Depends(org_viewer_dep),
pageSize: int = DEFAULT_PAGE_SIZE,
page: int = 1,
sortBy: str = None,
sortDirection: int = 1,
name: Optional[str] = None,
):
collections, total = await colls.list_collections(
org.id, page_size=pageSize, page=page
org.id,
page_size=pageSize,
page=page,
sort_by=sortBy,
sort_direction=sortDirection,
name=name,
)
return paginated_format(collections, total, page, pageSize)
@ -247,10 +308,13 @@ def init_collections_api(app, mdb, crawls, orgs, crawl_manager):
return results
@app.get(
"/orgs/{oid}/collections/{coll_id}",
tags=["collections"],
)
@app.get("/orgs/{oid}/collections/names", tags=["collections"])
async def get_collection_names(
org: Organization = Depends(org_viewer_dep),
):
return await colls.get_collection_names(org)
@app.get("/orgs/{oid}/collections/{coll_id}", tags=["collections"])
async def get_collection_crawls(
coll_id: uuid.UUID, org: Organization = Depends(org_viewer_dep)
):

View File

@ -525,7 +525,7 @@ class CrawlConfigOps:
aggregate.extend([{"$match": {"firstSeed": first_seed}}])
if sort_by:
if sort_by not in ("created, modified, firstSeed, lastCrawlTime"):
if sort_by not in ("created", "modified", "firstSeed", "lastCrawlTime"):
raise HTTPException(status_code=400, detail="invalid_sort_by")
if sort_direction not in (1, -1):
raise HTTPException(status_code=400, detail="invalid_sort_direction")

View File

@ -187,3 +187,82 @@ def test_list_collections(
assert second_coll["oid"] == default_org_id
assert second_coll.get("description") is None
assert second_coll["crawlIds"] == [crawler_crawl_id]
def test_filter_sort_collections(
    crawler_auth_headers, default_org_id, crawler_crawl_id, admin_crawl_id
):
    """Exercise the collection list endpoint's name filter and sorting."""

    def get_collections(query):
        # Issue a list request with the given query string, assert success,
        # and return the decoded JSON body.
        resp = requests.get(
            f"{API_PREFIX}/orgs/{default_org_id}/collections?{query}",
            headers=crawler_auth_headers,
        )
        assert resp.status_code == 200
        return resp.json()

    # Filter by name: exactly one matching collection comes back
    data = get_collections(f"name={SECOND_COLLECTION_NAME}")
    assert data["total"] == 1
    items = data["items"]
    assert len(items) == 1
    coll = items[0]
    assert coll["id"]
    assert coll["name"] == SECOND_COLLECTION_NAME
    assert coll["oid"] == default_org_id
    assert coll.get("description") is None
    assert coll["crawlIds"] == [crawler_crawl_id]

    # Sort by name, ascending (default direction)
    data = get_collections("sortBy=name")
    assert data["total"] == 2
    names = [item["name"] for item in data["items"]]
    assert names == [SECOND_COLLECTION_NAME, UPDATED_NAME]

    # Sort by name, descending
    data = get_collections("sortBy=name&sortDirection=-1")
    assert data["total"] == 2
    names = [item["name"] for item in data["items"]]
    assert names == [UPDATED_NAME, SECOND_COLLECTION_NAME]

    # Sort by description, ascending: the null description sorts first
    data = get_collections("sortBy=description")
    assert data["total"] == 2
    items = data["items"]
    assert items[0]["name"] == SECOND_COLLECTION_NAME
    assert items[0].get("description") is None
    assert items[1]["name"] == UPDATED_NAME
    assert items[1]["description"] == DESCRIPTION

    # Sort by description, descending
    data = get_collections("sortBy=description&sortDirection=-1")
    assert data["total"] == 2
    items = data["items"]
    assert items[0]["name"] == UPDATED_NAME
    assert items[0]["description"] == DESCRIPTION
    assert items[1]["name"] == SECOND_COLLECTION_NAME
    assert items[1].get("description") is None