Add API endpoint to update crawl tags (#545)

* Add API endpoint to update crawls (tags only for now)
* Allow setting tags to empty list in crawlconfig updates
Tessa Walsh authored 2023-02-01 22:24:36 -05:00, committed by GitHub
parent 23022193fb
commit 2e3b3cb228
4 changed files with 97 additions and 3 deletions


@@ -317,9 +317,7 @@ class CrawlConfigOps:
         """Update name, scale, schedule, and/or tags for an existing crawl config"""

         # set update query
-        query = update.dict(
-            exclude_unset=True, exclude_defaults=True, exclude_none=True
-        )
+        query = update.dict(exclude_unset=True, exclude_none=True)

         if len(query) == 0:
             raise HTTPException(status_code=400, detail="no_update_data")
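
Dropping exclude_defaults=True here is what makes clearing tags possible: an explicitly submitted tags=[] equals the field's default, so it was filtered out of the update dict before ever reaching the $set query. A minimal sketch of the pydantic v1 behavior, using a hypothetical stand-in for the commit's update models:

from typing import List, Optional
from pydantic import BaseModel


class UpdateExample(BaseModel):
    # hypothetical stand-in for the update models in this commit
    tags: Optional[List[str]] = []


update = UpdateExample(tags=[])

# exclude_defaults drops any field equal to its default, even when the
# client sent it explicitly, so {"tags": []} was silently discarded:
assert update.dict(exclude_unset=True, exclude_defaults=True, exclude_none=True) == {}

# without exclude_defaults, the explicit empty list survives:
assert update.dict(exclude_unset=True, exclude_none=True) == {"tags": []}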


@@ -147,6 +147,13 @@ class CrawlCompleteIn(BaseModel):
     completed: Optional[bool] = True


+# ============================================================================
+class UpdateCrawl(BaseModel):
+    """Update crawl tags"""
+
+    tags: Optional[List[str]] = []
+
+
 # ============================================================================
 class CrawlOps:
     """Crawl Ops"""
@@ -368,6 +375,25 @@ class CrawlOps:

         # print(f"Crawl Already Added: {crawl.id} - {crawl.state}")
         return False

+    async def update_crawl(self, crawl_id: str, org: Organization, update: UpdateCrawl):
+        """Update existing crawl (tags only for now)"""
+        query = update.dict(exclude_unset=True, exclude_none=True)
+
+        if len(query) == 0:
+            raise HTTPException(status_code=400, detail="no_update_data")
+
+        # update in db
+        result = await self.crawls.find_one_and_update(
+            {"_id": crawl_id, "oid": org.id},
+            {"$set": query},
+            return_document=pymongo.ReturnDocument.AFTER,
+        )
+
+        if not result:
+            raise HTTPException(status_code=404, detail=f"Crawl '{crawl_id}' not found")
+
+        return {"success": True}
+
     async def update_crawl_state(self, crawl_id: str, state: str):
         """called only when job container is being stopped/canceled"""
@@ -680,6 +706,12 @@ def init_crawls_api(app, mdb, users, crawl_manager, crawl_config_ops, orgs, user

         return crawls[0]

+    @app.patch("/orgs/{oid}/crawls/{crawl_id}", tags=["crawls"])
+    async def update_crawl(
+        update: UpdateCrawl, crawl_id: str, org: Organization = Depends(org_crawl_dep)
+    ):
+        return await ops.update_crawl(crawl_id, org, update)
+
     @app.post(
         "/orgs/{oid}/crawls/{crawl_id}/scale",
         tags=["crawls"],
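
A sketch of exercising the new endpoint with requests; the base URL, ids, and token below are placeholders, not values from this commit:

import requests

API_PREFIX = "https://example.com/api"  # placeholder deployment URL
headers = {"Authorization": "Bearer <token>"}  # placeholder auth

# Replace a crawl's tags wholesale:
r = requests.patch(
    f"{API_PREFIX}/orgs/<oid>/crawls/<crawl_id>",
    headers=headers,
    json={"tags": ["qa", "monthly"]},
)
assert r.status_code == 200 and r.json() == {"success": True}

# A body that sets no fields is rejected:
r = requests.patch(
    f"{API_PREFIX}/orgs/<oid>/crawls/<crawl_id>",
    headers=headers,
    json={},
)
assert r.status_code == 400  # detail: "no_update_data"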


@@ -40,3 +40,21 @@ def test_add_update_crawl_config(
     data = r.json()
     assert data["name"] == UPDATED_NAME
     assert sorted(data["tags"]) == sorted(UPDATED_TAGS)
+
+    # Verify that deleting tags works as well
+    r = requests.patch(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawlconfigs/{cid}/",
+        headers=crawler_auth_headers,
+        json={"tags": []},
+    )
+    assert r.status_code == 200
+
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawlconfigs/{cid}/",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert data["name"] == UPDATED_NAME
+    assert data["tags"] == []


@@ -96,3 +96,49 @@ def test_verify_wacz():

     pages = z.open("pages/pages.jsonl").read().decode("utf-8")
     assert '"https://webrecorder.net/"' in pages
+
+
+def test_update_tags(admin_auth_headers, default_org_id, admin_crawl_id):
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{admin_crawl_id}",
+        headers=admin_auth_headers,
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert sorted(data["tags"]) == ["wr-test-1", "wr-test-2"]
+
+    # Submit patch request to update tags
+    UPDATED_TAGS = ["wr-test-1-updated", "wr-test-2-updated"]
+    r = requests.patch(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{admin_crawl_id}",
+        headers=admin_auth_headers,
+        json={"tags": UPDATED_TAGS},
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert data["success"]
+
+    # Verify update was successful
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{admin_crawl_id}",
+        headers=admin_auth_headers,
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert sorted(data["tags"]) == sorted(UPDATED_TAGS)
+
+    # Verify deleting all tags works as well
+    r = requests.patch(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{admin_crawl_id}",
+        headers=admin_auth_headers,
+        json={"tags": []},
+    )
+    assert r.status_code == 200
+
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{admin_crawl_id}",
+        headers=admin_auth_headers,
+    )
+    assert r.status_code == 200
+    data = r.json()
+    assert data["tags"] == []