Add QA run aggregate stats API endpoint (#1682)
Fixes #1659. Takes an arbitrary set of thresholds for text and screenshot matches, each as a comma-separated list of floats. For each, returns a list of groupings that include the lower boundary and count for every threshold passed in.
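For illustration, a minimal sketch of calling the new endpoint; the host, IDs, and token below are placeholders, not values from this change:

    import requests

    # Placeholder deployment and IDs -- substitute real ones
    API_PREFIX = "https://btrix.example.com/api"
    ORG_ID = "<org-uuid>"
    CRAWL_ID = "<crawl-id>"
    QA_RUN_ID = "<qa-run-id>"

    r = requests.get(
        f"{API_PREFIX}/orgs/{ORG_ID}/crawls/{CRAWL_ID}/qa/{QA_RUN_ID}/stats"
        "?screenshotThresholds=0.5,0.9&textThresholds=0.5,0.9",
        headers={"Authorization": "Bearer <token>"},
    )
    print(r.json())
    # -> {"screenshotMatch": [{"lowerBoundary": "0.0", "count": ...}, ...],
    #     "textMatch": [{"lowerBoundary": "0.0", "count": ...}, ...]}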
parent 835014d829
commit 30ab139ff2
@@ -35,6 +35,7 @@ from .models import (
     QARun,
     QARunOut,
     QARunWithResources,
+    QARunAggregateStatsOut,
     DeleteQARunList,
     Organization,
     User,
@@ -917,6 +918,23 @@ class CrawlOps(BaseCrawlOps):

         return QARunWithResources(**qa_run_dict)

+    async def get_qa_run_aggregate_stats(
+        self,
+        crawl_id: str,
+        qa_run_id: str,
+        thresholds: Dict[str, List[float]],
+    ) -> QARunAggregateStatsOut:
+        """Get aggregate stats for QA run"""
+        screenshot_results = await self.page_ops.get_qa_run_aggregate_counts(
+            crawl_id, qa_run_id, thresholds, key="screenshotMatch"
+        )
+        text_results = await self.page_ops.get_qa_run_aggregate_counts(
+            crawl_id, qa_run_id, thresholds, key="textMatch"
+        )
+        return QARunAggregateStatsOut(
+            screenshotMatch=screenshot_results, textMatch=text_results
+        )
+

 # ============================================================================
 async def recompute_crawl_file_count_and_size(crawls, crawl_id):
@@ -1125,6 +1143,37 @@ def init_crawls_api(crawl_manager: CrawlManager, app, user_dep, *args):
     ):
         return await ops.get_qa_run_for_replay(crawl_id, qa_run_id, org)

+    @app.get(
+        "/orgs/{oid}/crawls/{crawl_id}/qa/{qa_run_id}/stats",
+        tags=["qa"],
+        response_model=QARunAggregateStatsOut,
+    )
+    async def get_qa_run_aggregate_stats(
+        crawl_id,
+        qa_run_id,
+        screenshotThresholds: str,
+        textThresholds: str,
+        # pylint: disable=unused-argument
+        org: Organization = Depends(org_viewer_dep),
+    ):
+        thresholds: Dict[str, List[float]] = {}
+        try:
+            thresholds["screenshotMatch"] = [
+                float(threshold) for threshold in screenshotThresholds.split(",")
+            ]
+            thresholds["textMatch"] = [
+                float(threshold) for threshold in textThresholds.split(",")
+            ]
+        # pylint: disable=broad-exception-caught,raise-missing-from
+        except Exception:
+            raise HTTPException(status_code=400, detail="invalid_thresholds")
+
+        return await ops.get_qa_run_aggregate_stats(
+            crawl_id,
+            qa_run_id,
+            thresholds,
+        )
+
     @app.post("/orgs/{oid}/crawls/{crawl_id}/qa/start", tags=["qa"])
     async def start_crawl_qa_run(
         crawl_id: str,
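As a quick illustration of the threshold parsing in the handler above (the input strings are arbitrary):

    # Comma-separated query values are split and parsed as floats
    screenshot_thresholds = "0.5,0.9"
    print([float(t) for t in screenshot_thresholds.split(",")])  # [0.5, 0.9]

    # A non-numeric value such as "null" makes float() raise ValueError,
    # which the handler maps to HTTP 400 with detail "invalid_thresholds"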
@@ -738,6 +738,22 @@ class QARunOut(BaseModel):
     stats: CrawlStats = CrawlStats()


+# ============================================================================
+class QARunBucketStats(BaseModel):
+    """Model for per-bucket aggregate stats results"""
+
+    lowerBoundary: str
+    count: int
+
+
+# ============================================================================
+class QARunAggregateStatsOut(BaseModel):
+    """QA Run aggregate stats out"""
+
+    screenshotMatch: List[QARunBucketStats]
+    textMatch: List[QARunBucketStats]
+
+
 # ============================================================================
 class Crawl(BaseCrawl, CrawlConfigCore):
     """Store State of a Crawl (Finished or Running)"""
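For reference, a standalone sketch of what a bucket serializes to, with made-up values (the class is redeclared here so the snippet runs on its own):

    from pydantic import BaseModel

    class QARunBucketStats(BaseModel):
        """Mirrors the model added above"""

        lowerBoundary: str
        count: int

    print(QARunBucketStats(lowerBoundary="0.7", count=3).dict())
    # {'lowerBoundary': '0.7', 'count': 3}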
@@ -22,6 +22,7 @@ from .models import (
     PageNoteIn,
     PageNoteEdit,
     PageNoteDelete,
+    QARunBucketStats,
 )
 from .pagination import DEFAULT_PAGE_SIZE, paginated_format
 from .utils import from_k8s_date, str_list_to_bools
@@ -514,6 +515,68 @@ class PageOps:
         for crawl_id in crawl_ids:
             await self.re_add_crawl_pages(crawl_id, oid)

+    async def get_qa_run_aggregate_counts(
+        self,
+        crawl_id: str,
+        qa_run_id: str,
+        thresholds: Dict[str, List[float]],
+        key: str = "screenshotMatch",
+    ):
+        """Get counts for pages in QA run in buckets by score key based on thresholds"""
+        boundaries = thresholds.get(key, [])
+        if not boundaries:
+            raise HTTPException(status_code=400, detail="missing_thresholds")
+
+        boundaries = sorted(boundaries)
+
+        # Make sure boundaries start with 0
+        if boundaries[0] != 0:
+            boundaries.insert(0, 0.0)
+
+        # Make sure we have upper boundary just over 1 to be inclusive of scores of 1
+        if boundaries[-1] <= 1:
+            boundaries.append(1.1)
+
+        aggregate = [
+            {"$match": {"crawl_id": crawl_id}},
+            {
+                "$bucket": {
+                    "groupBy": f"$qa.{qa_run_id}.{key}",
+                    "default": "No data",
+                    "boundaries": boundaries,
+                    "output": {
+                        "count": {"$sum": 1},
+                    },
+                }
+            },
+        ]
+        cursor = self.pages.aggregate(aggregate)
+        results = await cursor.to_list(length=len(boundaries))
+
+        return_data = []
+
+        for result in results:
+            return_data.append(
+                QARunBucketStats(
+                    lowerBoundary=str(result.get("_id")), count=result.get("count", 0)
+                )
+            )
+
+        # Add missing boundaries to result and re-sort
+        for boundary in boundaries:
+            if boundary < 1.0:
+                matching_return_data = [
+                    bucket
+                    for bucket in return_data
+                    if bucket.lowerBoundary == str(boundary)
+                ]
+                if not matching_return_data:
+                    return_data.append(
+                        QARunBucketStats(lowerBoundary=str(boundary), count=0)
+                    )
+
+        return sorted(return_data, key=lambda bucket: bucket.lowerBoundary)
+

 # ============================================================================
 # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
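A worked sketch of the boundary handling above, runnable on its own (the input thresholds are arbitrary):

    # Mirrors the normalization in get_qa_run_aggregate_counts
    boundaries = sorted([0.9, 0.5])   # -> [0.5, 0.9]
    if boundaries[0] != 0:
        boundaries.insert(0, 0.0)     # -> [0.0, 0.5, 0.9]
    if boundaries[-1] <= 1:
        boundaries.append(1.1)        # -> [0.0, 0.5, 0.9, 1.1]
    print(boundaries)

    # $bucket then counts pages whose score falls in [0.0, 0.5), [0.5, 0.9),
    # or [0.9, 1.1); the 1.1 upper bound keeps perfect scores of 1.0 inside
    # the last bucket. Pages with no score for the key land in the "No data"
    # default bucket, which is why lowerBoundary is typed as a string.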
@@ -283,6 +283,69 @@ def test_qa_replay(
     assert data["resources"][0]["path"]


+def test_qa_stats(
+    crawler_crawl_id,
+    crawler_auth_headers,
+    default_org_id,
+    qa_run_id,
+    qa_run_pages_ready,
+):
+    # We'll want to improve this test by having more pages to test
+    # if we can figure out stable page scores to test against
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0.7,0.9&textThresholds=0.7,0.9",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+
+    data = r.json()
+    assert data["screenshotMatch"] == [
+        {"lowerBoundary": "0.0", "count": 0},
+        {"lowerBoundary": "0.7", "count": 0},
+        {"lowerBoundary": "0.9", "count": 1},
+    ]
+    assert data["textMatch"] == [
+        {"lowerBoundary": "0.0", "count": 0},
+        {"lowerBoundary": "0.7", "count": 0},
+        {"lowerBoundary": "0.9", "count": 1},
+    ]
+
+    # Test we get expected results with explicit 0 boundary
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0,0.7,0.9&textThresholds=0,0.7,0.9",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+
+    data = r.json()
+    assert data["screenshotMatch"] == [
+        {"lowerBoundary": "0.0", "count": 0},
+        {"lowerBoundary": "0.7", "count": 0},
+        {"lowerBoundary": "0.9", "count": 1},
+    ]
+    assert data["textMatch"] == [
+        {"lowerBoundary": "0.0", "count": 0},
+        {"lowerBoundary": "0.7", "count": 0},
+        {"lowerBoundary": "0.9", "count": 1},
+    ]
+
+    # Test that missing threshold values result in 422 HTTPException
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0.7",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 422
+    assert r.json()["detail"][0]["msg"] == "field required"
+
+    # Test that invalid threshold values result in 400 HTTPException
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0.7&textThresholds=null",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 400
+    assert r.json()["detail"] == "invalid_thresholds"
+
+
 def test_run_qa_not_running(
     crawler_crawl_id,
     crawler_auth_headers,