Add QA run aggregate stats API endpoint (#1682)
Fixes #1659. Takes an arbitrary set of thresholds for text and screenshot matches as comma-separated lists of floats. For each, returns a list of groupings that includes the lower boundary and count for all thresholds passed in.
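A quick sketch of how a client might call the new endpoint (illustrative only: the base URL, IDs, token, and the counts in the sample response are placeholders, not part of this commit):

    import requests

    # Placeholder values for illustration
    API_PREFIX = "https://app.example.com/api"
    oid = "<org-id>"
    crawl_id = "<crawl-id>"
    qa_run_id = "<qa-run-id>"
    headers = {"Authorization": "Bearer <token>"}

    r = requests.get(
        f"{API_PREFIX}/orgs/{oid}/crawls/{crawl_id}/qa/{qa_run_id}/stats",
        params={"screenshotThresholds": "0.5,0.9", "textThresholds": "0.5,0.9"},
        headers=headers,
    )
    print(r.json())
    # Possible shape of the response (counts illustrative):
    # {"screenshotMatch": [{"lowerBoundary": "0.0", "count": 2},
    #                      {"lowerBoundary": "0.5", "count": 5},
    #                      {"lowerBoundary": "0.9", "count": 10}],
    #  "textMatch": [{"lowerBoundary": "0.0", "count": 1},
    #                {"lowerBoundary": "0.5", "count": 4},
    #                {"lowerBoundary": "0.9", "count": 12}]}
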
parent 835014d829
commit 30ab139ff2
@@ -35,6 +35,7 @@ from .models import (
     QARun,
     QARunOut,
     QARunWithResources,
+    QARunAggregateStatsOut,
     DeleteQARunList,
     Organization,
     User,
@@ -917,6 +918,23 @@ class CrawlOps(BaseCrawlOps):
 
         return QARunWithResources(**qa_run_dict)
 
+    async def get_qa_run_aggregate_stats(
+        self,
+        crawl_id: str,
+        qa_run_id: str,
+        thresholds: Dict[str, List[float]],
+    ) -> QARunAggregateStatsOut:
+        """Get aggregate stats for QA run"""
+        screenshot_results = await self.page_ops.get_qa_run_aggregate_counts(
+            crawl_id, qa_run_id, thresholds, key="screenshotMatch"
+        )
+        text_results = await self.page_ops.get_qa_run_aggregate_counts(
+            crawl_id, qa_run_id, thresholds, key="textMatch"
+        )
+        return QARunAggregateStatsOut(
+            screenshotMatch=screenshot_results, textMatch=text_results
+        )
+
 
 # ============================================================================
 async def recompute_crawl_file_count_and_size(crawls, crawl_id):
@@ -1125,6 +1143,37 @@ def init_crawls_api(crawl_manager: CrawlManager, app, user_dep, *args):
     ):
         return await ops.get_qa_run_for_replay(crawl_id, qa_run_id, org)
 
+    @app.get(
+        "/orgs/{oid}/crawls/{crawl_id}/qa/{qa_run_id}/stats",
+        tags=["qa"],
+        response_model=QARunAggregateStatsOut,
+    )
+    async def get_qa_run_aggregate_stats(
+        crawl_id,
+        qa_run_id,
+        screenshotThresholds: str,
+        textThresholds: str,
+        # pylint: disable=unused-argument
+        org: Organization = Depends(org_viewer_dep),
+    ):
+        thresholds: Dict[str, List[float]] = {}
+        try:
+            thresholds["screenshotMatch"] = [
+                float(threshold) for threshold in screenshotThresholds.split(",")
+            ]
+            thresholds["textMatch"] = [
+                float(threshold) for threshold in textThresholds.split(",")
+            ]
+        # pylint: disable=broad-exception-caught,raise-missing-from
+        except Exception:
+            raise HTTPException(status_code=400, detail="invalid_thresholds")
+
+        return await ops.get_qa_run_aggregate_stats(
+            crawl_id,
+            qa_run_id,
+            thresholds,
+        )
+
     @app.post("/orgs/{oid}/crawls/{crawl_id}/qa/start", tags=["qa"])
     async def start_crawl_qa_run(
         crawl_id: str,

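Threshold parsing in the endpoint above is deliberately simple: each query parameter is split on commas and each piece must survive float(). A minimal sketch of what that accepts and rejects (values taken from the tests further down):

    [float(t) for t in "0.7,0.9".split(",")]  # -> [0.7, 0.9]
    [float(t) for t in "null".split(",")]     # raises ValueError
    # The handler catches that and re-raises it as
    # HTTPException(status_code=400, detail="invalid_thresholds")
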
@@ -738,6 +738,22 @@ class QARunOut(BaseModel):
     stats: CrawlStats = CrawlStats()
 
 
+# ============================================================================
+class QARunBucketStats(BaseModel):
+    """Model for per-bucket aggregate stats results"""
+
+    lowerBoundary: str
+    count: int
+
+
+# ============================================================================
+class QARunAggregateStatsOut(BaseModel):
+    """QA Run aggregate stats out"""
+
+    screenshotMatch: List[QARunBucketStats]
+    textMatch: List[QARunBucketStats]
+
+
 # ============================================================================
 class Crawl(BaseCrawl, CrawlConfigCore):
     """Store State of a Crawl (Finished or Running)"""

@@ -22,6 +22,7 @@ from .models import (
     PageNoteIn,
     PageNoteEdit,
     PageNoteDelete,
+    QARunBucketStats,
 )
 from .pagination import DEFAULT_PAGE_SIZE, paginated_format
 from .utils import from_k8s_date, str_list_to_bools
@@ -514,6 +515,68 @@ class PageOps:
         for crawl_id in crawl_ids:
             await self.re_add_crawl_pages(crawl_id, oid)
 
+    async def get_qa_run_aggregate_counts(
+        self,
+        crawl_id: str,
+        qa_run_id: str,
+        thresholds: Dict[str, List[float]],
+        key: str = "screenshotMatch",
+    ):
+        """Get counts for pages in QA run in buckets by score key based on thresholds"""
+        boundaries = thresholds.get(key, [])
+        if not boundaries:
+            raise HTTPException(status_code=400, detail="missing_thresholds")
+
+        boundaries = sorted(boundaries)
+
+        # Make sure boundaries start with 0
+        if boundaries[0] != 0:
+            boundaries.insert(0, 0.0)
+
+        # Make sure we have upper boundary just over 1 to be inclusive of scores of 1
+        if boundaries[-1] <= 1:
+            boundaries.append(1.1)
+
+        aggregate = [
+            {"$match": {"crawl_id": crawl_id}},
+            {
+                "$bucket": {
+                    "groupBy": f"$qa.{qa_run_id}.{key}",
+                    "default": "No data",
+                    "boundaries": boundaries,
+                    "output": {
+                        "count": {"$sum": 1},
+                    },
+                }
+            },
+        ]
+        cursor = self.pages.aggregate(aggregate)
+        results = await cursor.to_list(length=len(boundaries))
+
+        return_data = []
+
+        for result in results:
+            return_data.append(
+                QARunBucketStats(
+                    lowerBoundary=str(result.get("_id")), count=result.get("count", 0)
+                )
+            )
+
+        # Add missing boundaries to result and re-sort
+        for boundary in boundaries:
+            if boundary < 1.0:
+                matching_return_data = [
+                    bucket
+                    for bucket in return_data
+                    if bucket.lowerBoundary == str(boundary)
+                ]
+                if not matching_return_data:
+                    return_data.append(
+                        QARunBucketStats(lowerBoundary=str(boundary), count=0)
+                    )
+
+        return sorted(return_data, key=lambda bucket: bucket.lowerBoundary)
+
 
 # ============================================================================
 # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme

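To make the bucketing concrete, here is a minimal sketch of how get_qa_run_aggregate_counts pads the boundaries, assuming thresholds of 0.7 and 0.9 (the same values the tests below use):

    boundaries = sorted([0.7, 0.9])  # [0.7, 0.9]
    if boundaries[0] != 0:
        boundaries.insert(0, 0.0)    # [0.0, 0.7, 0.9]
    if boundaries[-1] <= 1:
        boundaries.append(1.1)       # [0.0, 0.7, 0.9, 1.1]
    # $bucket then counts pages whose score falls in [0.0, 0.7),
    # [0.7, 0.9), or [0.9, 1.1); the 1.1 upper boundary keeps a
    # perfect score of 1 inside the last bucket, and pages with no
    # score for this QA run land in the "No data" default bucket.
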
@@ -283,6 +283,69 @@
     assert data["resources"][0]["path"]
 
 
+def test_qa_stats(
+    crawler_crawl_id,
+    crawler_auth_headers,
+    default_org_id,
+    qa_run_id,
+    qa_run_pages_ready,
+):
+    # We'll want to improve this test by having more pages to test
+    # if we can figure out stable page scores to test against
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0.7,0.9&textThresholds=0.7,0.9",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+
+    data = r.json()
+    assert data["screenshotMatch"] == [
+        {"lowerBoundary": "0.0", "count": 0},
+        {"lowerBoundary": "0.7", "count": 0},
+        {"lowerBoundary": "0.9", "count": 1},
+    ]
+    assert data["textMatch"] == [
+        {"lowerBoundary": "0.0", "count": 0},
+        {"lowerBoundary": "0.7", "count": 0},
+        {"lowerBoundary": "0.9", "count": 1},
+    ]
+
+    # Test we get expected results with explicit 0 boundary
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0,0.7,0.9&textThresholds=0,0.7,0.9",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 200
+
+    data = r.json()
+    assert data["screenshotMatch"] == [
+        {"lowerBoundary": "0.0", "count": 0},
+        {"lowerBoundary": "0.7", "count": 0},
+        {"lowerBoundary": "0.9", "count": 1},
+    ]
+    assert data["textMatch"] == [
+        {"lowerBoundary": "0.0", "count": 0},
+        {"lowerBoundary": "0.7", "count": 0},
+        {"lowerBoundary": "0.9", "count": 1},
+    ]
+
+    # Test that missing threshold values result in 422 HTTPException
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0.7",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 422
+    assert r.json()["detail"][0]["msg"] == "field required"
+
+    # Test that invalid threshold values result in 400 HTTPException
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawler_crawl_id}/qa/{qa_run_id}/stats?screenshotThresholds=0.7&textThresholds=null",
+        headers=crawler_auth_headers,
+    )
+    assert r.status_code == 400
+    assert r.json()["detail"] == "invalid_thresholds"
+
+
 def test_run_qa_not_running(
     crawler_crawl_id,
     crawler_auth_headers,