Add API endpoint to retry all failed bg jobs (#1396)

Fixes #1395 

- Adds a new `POST /orgs/<orgid>/jobs/retryFailed` API endpoint to retry all failed
background jobs for a specific org.
- Also adds `POST /orgs/all/jobs/retryFailed`, which lets superadmins retry all failed background jobs across all orgs. See the usage sketch below.
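
A minimal usage sketch with Python `requests`, matching the style of the test suite; the base URL, token, and org id are placeholders, not values from this change:

```python
import requests

API_PREFIX = "https://app.example.com/api"  # placeholder deployment URL
headers = {"Authorization": "Bearer <access-token>"}  # placeholder token
org_id = "<org-uuid>"  # placeholder org id

# Org-scoped: retry all failed background jobs for one org
r = requests.post(f"{API_PREFIX}/orgs/{org_id}/jobs/retryFailed", headers=headers)
assert r.json()["success"]

# Superadmin-only: retry failed jobs across all orgs; non-superusers get a 403
r = requests.post(f"{API_PREFIX}/orgs/all/jobs/retryFailed", headers=headers)
```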
Tessa Walsh 2023-12-05 16:00:45 -05:00 committed by GitHub
parent 26636f5386
commit 478b794f9b
4 changed files with 98 additions and 3 deletions


@@ -21,6 +21,7 @@ from .models import (
     PaginatedResponse,
     AnyJob,
     StorageRef,
+    User,
 )
 from .pagination import DEFAULT_PAGE_SIZE, paginated_format
@@ -413,7 +414,7 @@ class BackgroundJobOps:
     async def retry_background_job(
         self, job_id: str, org: Organization
     ) -> Dict[str, Union[bool, Optional[str]]]:
-        """Retry background job and return new job id"""
+        """Retry background job"""
         job = await self.get_background_job(job_id, org.id)
         if not job:
             raise HTTPException(status_code=404, detail="job_not_found")
@@ -455,11 +456,42 @@ class BackgroundJobOps:

         return {"success": True}

+    async def retry_failed_background_jobs(
+        self, org: Organization
+    ) -> Dict[str, Union[bool, Optional[str]]]:
+        """Retry all failed background jobs in an org
+
+        Keep track of tasks in set to prevent them from being garbage collected
+        See: https://stackoverflow.com/a/74059981
+        """
+        bg_tasks = set()
+        async for job in self.jobs.find({"oid": org.id, "success": False}):
+            task = asyncio.create_task(self.retry_background_job(job["_id"], org))
+            bg_tasks.add(task)
+            task.add_done_callback(bg_tasks.discard)
+
+        return {"success": True}
+
+    async def retry_all_failed_background_jobs(
+        self,
+    ) -> Dict[str, Union[bool, Optional[str]]]:
+        """Retry all failed background jobs from all orgs
+
+        Keep track of tasks in set to prevent them from being garbage collected
+        See: https://stackoverflow.com/a/74059981
+        """
+        bg_tasks = set()
+        async for job in self.jobs.find({"success": False}):
+            org = await self.org_ops.get_org_by_id(job["oid"])
+            task = asyncio.create_task(self.retry_background_job(job["_id"], org))
+            bg_tasks.add(task)
+            task.add_done_callback(bg_tasks.discard)
+
+        return {"success": True}
+

 # ============================================================================
 # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
 def init_background_jobs_api(
-    mdb, email, user_manager, org_ops, crawl_manager, storage_ops
+    app, mdb, email, user_manager, org_ops, crawl_manager, storage_ops, user_dep
 ):
     """init background jobs system"""
     # pylint: disable=invalid-name
@@ -494,6 +526,25 @@ def init_background_jobs_api(
         """Retry background job"""
         return await ops.retry_background_job(job_id, org)

+    @app.post(
+        "/orgs/all/jobs/retryFailed",
+    )
+    async def retry_all_failed_background_jobs(user: User = Depends(user_dep)):
+        """Retry failed background jobs from all orgs"""
+        if not user.is_superuser:
+            raise HTTPException(status_code=403, detail="Not Allowed")
+
+        return await ops.retry_all_failed_background_jobs()
+
+    @router.post(
+        "/retryFailed",
+    )
+    async def retry_failed_background_jobs(
+        org: Organization = Depends(org_crawl_dep),
+    ):
+        """Retry failed background jobs"""
+        return await ops.retry_failed_background_jobs(org)
+
     @router.get("", response_model=PaginatedResponse)
     async def list_background_jobs(
         org: Organization = Depends(org_crawl_dep),
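
Both new methods rely on the fire-and-forget task pattern their docstrings reference: `asyncio` holds only weak references to tasks, so a task with no other reference can be garbage collected before it completes. A minimal standalone sketch of the pattern, with a hypothetical `do_work` coroutine standing in for `retry_background_job`:

```python
import asyncio

# Strong references keep the fire-and-forget tasks alive until done.
background_tasks = set()

async def do_work(n: int) -> None:
    await asyncio.sleep(0.1)
    print(f"job {n} retried")

async def main() -> None:
    for n in range(3):
        task = asyncio.create_task(do_work(n))
        background_tasks.add(task)  # hold a strong reference
        task.add_done_callback(background_tasks.discard)  # release when finished
    await asyncio.sleep(0.2)  # let the background tasks complete

asyncio.run(main())
```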


@@ -90,7 +90,14 @@ def main():
     storage_ops = init_storages_api(org_ops, crawl_manager)

     background_job_ops = init_background_jobs_api(
-        mdb, email, user_manager, org_ops, crawl_manager, storage_ops
+        app,
+        mdb,
+        email,
+        user_manager,
+        org_ops,
+        crawl_manager,
+        storage_ops,
+        current_active_user,
     )

     profiles = init_profiles_api(


@@ -10,6 +10,9 @@ API_PREFIX = HOST_PREFIX + "/api"
 ADMIN_USERNAME = "admin@example.com"
 ADMIN_PW = "PASSW0RD!"

+CRAWLER_USERNAME = "crawlernightly@example.com"
+CRAWLER_PW = "crawlerPASSWORD!"
+

 @pytest.fixture(scope="session")
 def admin_auth_headers():
@@ -44,6 +47,32 @@ def default_org_id(admin_auth_headers):
         time.sleep(5)


+@pytest.fixture(scope="session")
+def crawler_auth_headers(admin_auth_headers, default_org_id):
+    requests.post(
+        f"{API_PREFIX}/orgs/{default_org_id}/add-user",
+        json={
+            "email": CRAWLER_USERNAME,
+            "password": CRAWLER_PW,
+            "name": "new-crawler",
+            "description": "crawler test crawl",
+            "role": 20,
+        },
+        headers=admin_auth_headers,
+    )
+    r = requests.post(
+        f"{API_PREFIX}/auth/jwt/login",
+        data={
+            "username": CRAWLER_USERNAME,
+            "password": CRAWLER_PW,
+            "grant_type": "password",
+        },
+    )
+    data = r.json()
+    access_token = data.get("access_token")
+    return {"Authorization": f"Bearer {access_token}"}
+
+
 @pytest.fixture(scope="session")
 def crawl_id_wr(admin_auth_headers, default_org_id):
     # Start crawl.


@@ -98,3 +98,11 @@ def test_get_background_job(admin_auth_headers, default_org_id):
     assert data["object_type"]
     assert data["object_id"]
     assert data["replica_storage"]
+
+
+def test_retry_all_failed_bg_jobs_not_superuser(crawler_auth_headers):
+    r = requests.post(
+        f"{API_PREFIX}/orgs/all/jobs/retryFailed", headers=crawler_auth_headers
+    )
+    assert r.status_code == 403
+    assert r.json()["detail"] == "Not Allowed"
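
For contrast, a hypothetical success-path test (not part of this commit) could exercise the org-scoped endpoint with the existing admin fixtures; per the handler above, it returns `{"success": True}` even when no failed jobs exist:

```python
def test_retry_failed_bg_jobs_org(admin_auth_headers, default_org_id):
    # Hypothetical sketch, not in this commit: the org-scoped retry
    # succeeds for an authorized org user.
    r = requests.post(
        f"{API_PREFIX}/orgs/{default_org_id}/jobs/retryFailed",
        headers=admin_auth_headers,
    )
    assert r.status_code == 200
    assert r.json()["success"]
```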