Add QA run aggregate stats API endpoint (#1682)
Fixes #1659 Takes an arbitrary set of thresholds for text and screenshot matches as a comma-separated list of floats. Returns a list of groupings for each that include the lower boundary and count for all thresholds passed in.
This commit is contained in:
parent
835014d829
commit
30ab139ff2
@ -35,6 +35,7 @@ from .models import (
|
||||
QARun,
|
||||
QARunOut,
|
||||
QARunWithResources,
|
||||
QARunAggregateStatsOut,
|
||||
DeleteQARunList,
|
||||
Organization,
|
||||
User,
|
||||
@ -917,6 +918,23 @@ class CrawlOps(BaseCrawlOps):
|
||||
|
||||
return QARunWithResources(**qa_run_dict)
|
||||
|
||||
async def get_qa_run_aggregate_stats(
    self,
    crawl_id: str,
    qa_run_id: str,
    thresholds: Dict[str, List[float]],
) -> QARunAggregateStatsOut:
    """Get aggregate stats for QA run.

    Computes bucketed page counts for both match-score keys
    ("screenshotMatch" and "textMatch") and returns them together.
    """
    counts = {}
    # Same aggregation, once per score key — screenshot first, then text
    for score_key in ("screenshotMatch", "textMatch"):
        counts[score_key] = await self.page_ops.get_qa_run_aggregate_counts(
            crawl_id, qa_run_id, thresholds, key=score_key
        )
    return QARunAggregateStatsOut(**counts)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
async def recompute_crawl_file_count_and_size(crawls, crawl_id):
|
||||
@ -1125,6 +1143,37 @@ def init_crawls_api(crawl_manager: CrawlManager, app, user_dep, *args):
|
||||
):
|
||||
return await ops.get_qa_run_for_replay(crawl_id, qa_run_id, org)
|
||||
|
||||
@app.get(
    "/orgs/{oid}/crawls/{crawl_id}/qa/{qa_run_id}/stats",
    tags=["qa"],
    response_model=QARunAggregateStatsOut,
)
async def get_qa_run_aggregate_stats(
    crawl_id,
    qa_run_id,
    screenshotThresholds: str,
    textThresholds: str,
    # pylint: disable=unused-argument
    org: Organization = Depends(org_viewer_dep),
):
    """Return aggregate match-score stats for a QA run.

    Both threshold query params are comma-separated lists of floats;
    a value that does not parse as a float yields a 400 with detail
    "invalid_thresholds".
    """

    def parse_thresholds(raw: str) -> List[float]:
        """Parse a comma-separated float list, mapping parse errors to 400."""
        try:
            return [float(threshold) for threshold in raw.split(",")]
        # float() raises ValueError on bad input — the only expected failure,
        # so catch it narrowly instead of a broad Exception
        except ValueError as err:
            raise HTTPException(
                status_code=400, detail="invalid_thresholds"
            ) from err

    thresholds: Dict[str, List[float]] = {
        "screenshotMatch": parse_thresholds(screenshotThresholds),
        "textMatch": parse_thresholds(textThresholds),
    }

    return await ops.get_qa_run_aggregate_stats(
        crawl_id,
        qa_run_id,
        thresholds,
    )
|
||||
|
||||
@app.post("/orgs/{oid}/crawls/{crawl_id}/qa/start", tags=["qa"])
|
||||
async def start_crawl_qa_run(
|
||||
crawl_id: str,
|
||||
|
@ -738,6 +738,22 @@ class QARunOut(BaseModel):
|
||||
stats: CrawlStats = CrawlStats()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
class QARunBucketStats(BaseModel):
    """Model for per-bucket aggregate stats results"""

    # Stored as a string rather than a float so the aggregation's default
    # bucket ("No data", for pages without a score) can share the field
    lowerBoundary: str
    # Number of pages whose score fell into this bucket
    count: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
class QARunAggregateStatsOut(BaseModel):
    """QA Run aggregate stats out"""

    # Bucketed page counts keyed by screenshot match score
    screenshotMatch: List[QARunBucketStats]
    # Bucketed page counts keyed by text match score
    textMatch: List[QARunBucketStats]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
class Crawl(BaseCrawl, CrawlConfigCore):
|
||||
"""Store State of a Crawl (Finished or Running)"""
|
||||
|
@ -22,6 +22,7 @@ from .models import (
|
||||
PageNoteIn,
|
||||
PageNoteEdit,
|
||||
PageNoteDelete,
|
||||
QARunBucketStats,
|
||||
)
|
||||
from .pagination import DEFAULT_PAGE_SIZE, paginated_format
|
||||
from .utils import from_k8s_date, str_list_to_bools
|
||||
@ -514,6 +515,68 @@ class PageOps:
|
||||
for crawl_id in crawl_ids:
|
||||
await self.re_add_crawl_pages(crawl_id, oid)
|
||||
|
||||
async def get_qa_run_aggregate_counts(
    self,
    crawl_id: str,
    qa_run_id: str,
    thresholds: Dict[str, List[float]],
    key: str = "screenshotMatch",
):
    """Get counts for pages in QA run in buckets by score key based on thresholds

    Raises a 400 HTTPException ("missing_thresholds") when no thresholds
    were supplied for *key*. Returns a list of QARunBucketStats sorted by
    stringified lower boundary, including zero-count entries for
    boundaries below 1.0 that matched no pages.
    """
    boundaries = thresholds.get(key, [])
    if not boundaries:
        raise HTTPException(status_code=400, detail="missing_thresholds")

    # sorted() returns a new list, so the caller's threshold list is untouched
    boundaries = sorted(boundaries)

    # Make sure boundaries start with 0
    if boundaries[0] != 0:
        boundaries.insert(0, 0.0)

    # Make sure we have upper boundary just over 1 to be inclusive of scores of 1
    if boundaries[-1] <= 1:
        boundaries.append(1.1)

    aggregate = [
        {"$match": {"crawl_id": crawl_id}},
        {
            "$bucket": {
                "groupBy": f"$qa.{qa_run_id}.{key}",
                # pages lacking a score for this QA run land in "No data"
                "default": "No data",
                "boundaries": boundaries,
                "output": {
                    "count": {"$sum": 1},
                },
            }
        },
    ]
    cursor = self.pages.aggregate(aggregate)
    # at most one bucket per boundary gap plus the default bucket,
    # so len(boundaries) is a sufficient fetch limit
    results = await cursor.to_list(length=len(boundaries))

    return_data = [
        QARunBucketStats(
            lowerBoundary=str(result.get("_id")), count=result.get("count", 0)
        )
        for result in results
    ]

    # Add missing boundaries to result and re-sort. Build the membership
    # set once instead of rescanning return_data per boundary (the
    # original inner scan was O(n^2) in the number of boundaries).
    present = {bucket.lowerBoundary for bucket in return_data}
    for boundary in boundaries:
        # NOTE(review): boundaries >= 1.0 (including the appended 1.1
        # sentinel) are never back-filled, so an explicit 1.0 threshold
        # gets no empty bucket — behavior preserved from the original,
        # confirm if intended
        if boundary < 1.0 and str(boundary) not in present:
            return_data.append(
                QARunBucketStats(lowerBoundary=str(boundary), count=0)
            )

    return sorted(return_data, key=lambda bucket: bucket.lowerBoundary)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
|
||||
|
@ -283,6 +283,69 @@ def test_qa_replay(
|
||||
assert data["resources"][0]["path"]
|
||||
|
||||
|
||||
def test_qa_stats(
    crawler_crawl_id,
    crawler_auth_headers,
    default_org_id,
    qa_run_id,
    qa_run_pages_ready,
):
    """Exercise the QA run aggregate stats endpoint: happy paths and errors."""
    # We'll want to improve this test by having more pages to test
    # if we can figure out stable page scores to test against
    stats_url = (
        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}"
        f"/qa/{qa_run_id}/stats"
    )
    expected_buckets = [
        {"lowerBoundary": "0.0", "count": 0},
        {"lowerBoundary": "0.7", "count": 0},
        {"lowerBoundary": "0.9", "count": 1},
    ]

    # Same buckets whether the 0 boundary is implicit or explicitly passed
    for thresholds in ("0.7,0.9", "0,0.7,0.9"):
        r = requests.get(
            f"{stats_url}?screenshotThresholds={thresholds}&textThresholds={thresholds}",
            headers=crawler_auth_headers,
        )
        assert r.status_code == 200

        data = r.json()
        assert data["screenshotMatch"] == expected_buckets
        assert data["textMatch"] == expected_buckets

    # Test that missing threshold values result in 422 HTTPException
    r = requests.get(
        f"{stats_url}?screenshotThresholds=0.7",
        headers=crawler_auth_headers,
    )
    assert r.status_code == 422
    assert r.json()["detail"][0]["msg"] == "field required"

    # Test that invalid threshold values result in 400 HTTPException
    r = requests.get(
        f"{stats_url}?screenshotThresholds=0.7&textThresholds=null",
        headers=crawler_auth_headers,
    )
    assert r.status_code == 400
    assert r.json()["detail"] == "invalid_thresholds"
|
||||
|
||||
|
||||
def test_run_qa_not_running(
|
||||
crawler_crawl_id,
|
||||
crawler_auth_headers,
|
||||
|
Loading…
Reference in New Issue
Block a user