Add QA run aggregate stats API endpoint (#1682)

Fixes #1659 

Takes an arbitrary set of thresholds for text and screenshot matches as
a comma-separated list of floats.

Returns, for each match type, a list of buckets giving the lower boundary and
page count for all thresholds passed in.
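
For illustration, a request against the new endpoint might look like the
following sketch (the IDs, headers, and counts are placeholders; API_PREFIX
stands in for the deployment's API root, as in the test suite):

    import requests

    r = requests.get(
        f"{API_PREFIX}/orgs/{oid}/crawls/{crawl_id}/qa/{qa_run_id}/stats"
        "?screenshotThresholds=0.7,0.9&textThresholds=0.7,0.9",
        headers=auth_headers,
    )
    # One bucket per lower boundary; a 0.0 bucket is added even if not
    # requested. Counts here are illustrative only:
    # {
    #     "screenshotMatch": [
    #         {"lowerBoundary": "0.0", "count": 0},
    #         {"lowerBoundary": "0.7", "count": 2},
    #         {"lowerBoundary": "0.9", "count": 5}
    #     ],
    #     "textMatch": [...]
    # }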
Tessa Walsh 2024-04-17 13:24:18 -04:00, committed by GitHub
commit 30ab139ff2 (parent 835014d829)
4 changed files with 191 additions and 0 deletions


@@ -35,6 +35,7 @@ from .models import (
    QARun,
    QARunOut,
    QARunWithResources,
    QARunAggregateStatsOut,
    DeleteQARunList,
    Organization,
    User,
@@ -917,6 +918,23 @@ class CrawlOps(BaseCrawlOps):
        return QARunWithResources(**qa_run_dict)

    async def get_qa_run_aggregate_stats(
        self,
        crawl_id: str,
        qa_run_id: str,
        thresholds: Dict[str, List[float]],
    ) -> QARunAggregateStatsOut:
        """Get aggregate stats for QA run"""
        screenshot_results = await self.page_ops.get_qa_run_aggregate_counts(
            crawl_id, qa_run_id, thresholds, key="screenshotMatch"
        )
        text_results = await self.page_ops.get_qa_run_aggregate_counts(
            crawl_id, qa_run_id, thresholds, key="textMatch"
        )
        return QARunAggregateStatsOut(
            screenshotMatch=screenshot_results, textMatch=text_results
        )


# ============================================================================
async def recompute_crawl_file_count_and_size(crawls, crawl_id):
@@ -1125,6 +1143,37 @@ def init_crawls_api(crawl_manager: CrawlManager, app, user_dep, *args):
    ):
        return await ops.get_qa_run_for_replay(crawl_id, qa_run_id, org)

    @app.get(
        "/orgs/{oid}/crawls/{crawl_id}/qa/{qa_run_id}/stats",
        tags=["qa"],
        response_model=QARunAggregateStatsOut,
    )
    async def get_qa_run_aggregate_stats(
        crawl_id,
        qa_run_id,
        screenshotThresholds: str,
        textThresholds: str,
        # pylint: disable=unused-argument
        org: Organization = Depends(org_viewer_dep),
    ):
        thresholds: Dict[str, List[float]] = {}
        try:
            thresholds["screenshotMatch"] = [
                float(threshold) for threshold in screenshotThresholds.split(",")
            ]
            thresholds["textMatch"] = [
                float(threshold) for threshold in textThresholds.split(",")
            ]
        # pylint: disable=broad-exception-caught,raise-missing-from
        except Exception:
            raise HTTPException(status_code=400, detail="invalid_thresholds")

        return await ops.get_qa_run_aggregate_stats(
            crawl_id,
            qa_run_id,
            thresholds,
        )

    @app.post("/orgs/{oid}/crawls/{crawl_id}/qa/start", tags=["qa"])
    async def start_crawl_qa_run(
        crawl_id: str,


@@ -738,6 +738,22 @@ class QARunOut(BaseModel):
    stats: CrawlStats = CrawlStats()


# ============================================================================
class QARunBucketStats(BaseModel):
    """Model for per-bucket aggregate stats results"""

    lowerBoundary: str
    count: int


# ============================================================================
class QARunAggregateStatsOut(BaseModel):
    """QA Run aggregate stats out"""

    screenshotMatch: List[QARunBucketStats]
    textMatch: List[QARunBucketStats]


# ============================================================================
class Crawl(BaseCrawl, CrawlConfigCore):
    """Store State of a Crawl (Finished or Running)"""


@@ -22,6 +22,7 @@ from .models import (
    PageNoteIn,
    PageNoteEdit,
    PageNoteDelete,
    QARunBucketStats,
)
from .pagination import DEFAULT_PAGE_SIZE, paginated_format
from .utils import from_k8s_date, str_list_to_bools
@@ -514,6 +515,68 @@ class PageOps:
        for crawl_id in crawl_ids:
            await self.re_add_crawl_pages(crawl_id, oid)

    async def get_qa_run_aggregate_counts(
        self,
        crawl_id: str,
        qa_run_id: str,
        thresholds: Dict[str, List[float]],
        key: str = "screenshotMatch",
    ):
        """Get counts for pages in QA run in buckets by score key based on thresholds"""
        boundaries = thresholds.get(key, [])
        if not boundaries:
            raise HTTPException(status_code=400, detail="missing_thresholds")

        boundaries = sorted(boundaries)

        # Make sure boundaries start with 0
        if boundaries[0] != 0:
            boundaries.insert(0, 0.0)

        # Make sure we have upper boundary just over 1 to be inclusive of scores of 1
        if boundaries[-1] <= 1:
            boundaries.append(1.1)

        aggregate = [
            {"$match": {"crawl_id": crawl_id}},
            {
                "$bucket": {
                    "groupBy": f"$qa.{qa_run_id}.{key}",
                    "default": "No data",
                    "boundaries": boundaries,
                    "output": {
                        "count": {"$sum": 1},
                    },
                }
            },
        ]

        cursor = self.pages.aggregate(aggregate)
        results = await cursor.to_list(length=len(boundaries))

        return_data = []

        for result in results:
            return_data.append(
                QARunBucketStats(
                    lowerBoundary=str(result.get("_id")), count=result.get("count", 0)
                )
            )

        # Add missing boundaries to result and re-sort
        for boundary in boundaries:
            if boundary < 1.0:
                matching_return_data = [
                    bucket
                    for bucket in return_data
                    if bucket.lowerBoundary == str(boundary)
                ]
                if not matching_return_data:
                    return_data.append(
                        QARunBucketStats(lowerBoundary=str(boundary), count=0)
                    )

        return sorted(return_data, key=lambda bucket: bucket.lowerBoundary)


# ============================================================================
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
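
As a sanity check on the boundary handling above, the same bucketing rule can
be sketched in pure Python (an illustration only, not part of this commit; it
mirrors MongoDB's $bucket semantics of inclusive lower and exclusive upper
bounds):

    from bisect import bisect_right

    def bucket_counts(scores: list[float], boundaries: list[float]) -> dict[float, int]:
        """Count scores per bucket, where bucket i spans
        [boundaries[i], boundaries[i+1]), lower bound inclusive."""
        bounds = sorted(boundaries)
        if bounds[0] != 0:
            bounds.insert(0, 0.0)  # same zero-padding as the method above
        if bounds[-1] <= 1:
            bounds.append(1.1)  # catch scores of exactly 1
        counts = {b: 0 for b in bounds[:-1]}
        for score in scores:
            # largest boundary <= score is the bucket's lower bound
            counts[bounds[bisect_right(bounds, score) - 1]] += 1
        return counts

    # bucket_counts([0.95, 1.0, 0.5], [0.7, 0.9]) -> {0.0: 1, 0.7: 0, 0.9: 2}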


@@ -283,6 +283,69 @@ def test_qa_replay(
    assert data["resources"][0]["path"]


def test_qa_stats(
    crawler_crawl_id,
    crawler_auth_headers,
    default_org_id,
    qa_run_id,
    qa_run_pages_ready,
):
    # We'll want to improve this test by having more pages to test
    # if we can figure out stable page scores to test against
    r = requests.get(
        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0.7,0.9&textThresholds=0.7,0.9",
        headers=crawler_auth_headers,
    )
    assert r.status_code == 200
    data = r.json()
    assert data["screenshotMatch"] == [
        {"lowerBoundary": "0.0", "count": 0},
        {"lowerBoundary": "0.7", "count": 0},
        {"lowerBoundary": "0.9", "count": 1},
    ]
    assert data["textMatch"] == [
        {"lowerBoundary": "0.0", "count": 0},
        {"lowerBoundary": "0.7", "count": 0},
        {"lowerBoundary": "0.9", "count": 1},
    ]

    # Test we get expected results with explicit 0 boundary
    r = requests.get(
        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0,0.7,0.9&textThresholds=0,0.7,0.9",
        headers=crawler_auth_headers,
    )
    assert r.status_code == 200
    data = r.json()
    assert data["screenshotMatch"] == [
        {"lowerBoundary": "0.0", "count": 0},
        {"lowerBoundary": "0.7", "count": 0},
        {"lowerBoundary": "0.9", "count": 1},
    ]
    assert data["textMatch"] == [
        {"lowerBoundary": "0.0", "count": 0},
        {"lowerBoundary": "0.7", "count": 0},
        {"lowerBoundary": "0.9", "count": 1},
    ]

    # Test that missing threshold values result in 422 HTTPException
    r = requests.get(
        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0.7",
        headers=crawler_auth_headers,
    )
    assert r.status_code == 422
    assert r.json()["detail"][0]["msg"] == "field required"

    # Test that invalid threshold values result in 400 HTTPException
    r = requests.get(
        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0.7&textThresholds=null",
        headers=crawler_auth_headers,
    )
    assert r.status_code == 400
    assert r.json()["detail"] == "invalid_thresholds"


def test_run_qa_not_running(
    crawler_crawl_id,
    crawler_auth_headers,