Optimize presigning for replay.json (#2516)

Fixes #2515.

This PR introduces significantly optimized logic for presigning URLs
for crawls and collections.
- For collections, the files needed from all crawls are looked up, and
then the 'presigned_urls' table is merged in one pass, resulting in a
unified iterator containing the files and the presigned URLs for those files.
- For crawls, the presigned URLs are also looked up once, and the same
iterator is used for a single crawl with a passed-in list of CrawlFiles.
- URLs that are already signed are added to the return list.
- For any remaining URLs to be signed, a bulk presigning function is
added, which shares an HTTP connection and signs 8 files in parallel
(customizable via the Helm chart, though this may not be needed). This
function is used to call the presigning API in parallel; see the sketch
after this list.
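
A minimal sketch of the two-step flow described above, assuming a motor/async `crawls` collection and an aiobotocore S3 client as used in the diff below; the helper names (`merge_presigned`, `sign_missing`) and the `BATCH_SIZE`/`ExpiresIn` values are illustrative, not the actual Browsertrix API:

```python
import asyncio
from typing import Any, AsyncIterator

BATCH_SIZE = 8  # mirrors the PRESIGN_BATCH_SIZE default added in this PR


async def merge_presigned(crawls_coll, match: dict[str, Any]) -> AsyncIterator[dict]:
    """One pass: join each crawl's files against the presigned_urls table,
    so files that already have a valid signed URL come back pre-merged."""
    cursor = crawls_coll.aggregate(
        [
            {"$match": match},
            {"$project": {"files": "$files", "version": 1}},
            {
                "$lookup": {
                    "from": "presigned_urls",
                    "localField": "files.filename",
                    "foreignField": "_id",
                    "as": "presigned",
                }
            },
        ]
    )
    async for result in cursor:
        yield result


async def sign_missing(client, bucket: str, key_prefix: str, filenames: list[str]):
    """Presign the still-unsigned files over one shared S3 client,
    BATCH_SIZE at a time via asyncio.gather."""
    urls: list[str] = []
    for i in range(0, len(filenames), BATCH_SIZE):
        batch = filenames[i : i + BATCH_SIZE]
        urls.extend(
            await asyncio.gather(
                *(
                    client.generate_presigned_url(
                        "get_object",
                        Params={"Bucket": bucket, "Key": key_prefix + name},
                        ExpiresIn=3600,  # illustrative; the PR uses PRESIGN_DURATION_SECONDS
                    )
                    for name in batch
                )
            )
        )
    return urls
```

The newly signed URLs are then cached back into `presigned_urls` (keyed by filename, with a TTL index on `signedAt`), so subsequent replay.json requests take the merge path instead of re-signing.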
Ilya Kreymer 2025-05-20 12:09:35 -07:00 committed by GitHub
parent f1fd11c031
commit c134b576ae
6 changed files with 283 additions and 88 deletions


@ -1,7 +1,18 @@
"""base crawl type"""
from datetime import datetime
-from typing import Optional, List, Union, Dict, Any, Type, TYPE_CHECKING, cast, Tuple
from typing import (
Optional,
List,
Union,
Dict,
Any,
Type,
TYPE_CHECKING,
cast,
Tuple,
AsyncIterable,
)
from uuid import UUID
import os
import urllib.parse
@ -76,6 +87,7 @@ class BaseCrawlOps:
background_job_ops: BackgroundJobOps,
):
self.crawls = mdb["crawls"]
self.presigned_urls = mdb["presigned_urls"]
self.crawl_configs = crawl_configs
self.user_manager = users
self.orgs = orgs
@ -464,29 +476,130 @@ class BaseCrawlOps:
) -> List[CrawlFileOut]:
"""Regenerate presigned URLs for files as necessary"""
if not files:
print("no files")
return []
-out_files = []
-for file_ in files:
-presigned_url, expire_at = await self.storage_ops.get_presigned_url(
-org, file_, force_update=force_update
cursor = self.presigned_urls.find(
{"_id": {"$in": [file.filename for file in files]}}
)
-out_files.append(
presigned = await cursor.to_list(10000)
files_dict = [file.dict() for file in files]
# need an async generator to call bulk_presigned_files
async def async_gen():
yield {"presigned": presigned, "files": files_dict, "_id": crawl_id}
out_files, _ = await self.bulk_presigned_files(async_gen(), org, force_update)
return out_files
async def get_presigned_files(
self, match: dict[str, Any], org: Organization
) -> tuple[list[CrawlFileOut], bool]:
"""return presigned crawl files queried as batch, merging presigns with files in one pass"""
cursor = self.crawls.aggregate(
[
{"$match": match},
{"$project": {"files": "$files", "version": 1}},
{
"$lookup": {
"from": "presigned_urls",
"localField": "files.filename",
"foreignField": "_id",
"as": "presigned",
}
},
]
)
return await self.bulk_presigned_files(cursor, org)
async def bulk_presigned_files(
self,
cursor: AsyncIterable[dict[str, Any]],
org: Organization,
force_update=False,
) -> tuple[list[CrawlFileOut], bool]:
"""process presigned files in batches"""
resources = []
pages_optimized = False
sign_files = []
async for result in cursor:
pages_optimized = result.get("version") == 2
mapping = {}
# create mapping of filename -> file data
for file in result["files"]:
file["crawl_id"] = result["_id"]
mapping[file["filename"]] = file
if not force_update:
# add already presigned resources
for presigned in result["presigned"]:
file = mapping.get(presigned["_id"])
if file:
file["signedAt"] = presigned["signedAt"]
file["path"] = presigned["url"]
resources.append(
CrawlFileOut(
-name=os.path.basename(file_.filename),
-path=presigned_url or "",
-hash=file_.hash,
-size=file_.size,
-crawlId=crawl_id,
-numReplicas=len(file_.replicas) if file_.replicas else 0,
name=os.path.basename(file["filename"]),
path=presigned["url"],
hash=file["hash"],
size=file["size"],
crawlId=file["crawl_id"],
numReplicas=len(file.get("replicas") or []),
expireAt=date_to_str(
presigned["signedAt"]
+ self.storage_ops.signed_duration_delta
),
)
)
del mapping[presigned["_id"]]
sign_files.extend(list(mapping.values()))
by_storage: dict[str, dict] = {}
for file in sign_files:
storage_ref = StorageRef(**file.get("storage"))
sid = str(storage_ref)
storage_group = by_storage.get(sid)
if not storage_group:
storage_group = {"ref": storage_ref, "names": [], "files": []}
by_storage[sid] = storage_group
storage_group["names"].append(file["filename"])
storage_group["files"].append(file)
for storage_group in by_storage.values():
s3storage = self.storage_ops.get_org_storage_by_ref(
org, storage_group["ref"]
)
signed_urls, expire_at = await self.storage_ops.get_presigned_urls_bulk(
org, s3storage, storage_group["names"]
)
for url, file in zip(signed_urls, storage_group["files"]):
resources.append(
CrawlFileOut(
name=os.path.basename(file["filename"]),
path=url,
hash=file["hash"],
size=file["size"],
crawlId=file["crawl_id"],
numReplicas=len(file.get("replicas") or []),
expireAt=date_to_str(expire_at),
)
)
-return out_files
return resources, pages_optimized
async def add_to_collection(
self, crawl_ids: List[str], collection_id: UUID, org: Organization


@ -28,7 +28,6 @@ from .models import (
UpdateColl,
AddRemoveCrawlList,
BaseCrawl,
-CrawlOutWithResources,
CrawlFileOut,
Organization,
PaginatedCollOutResponse,
@ -40,6 +39,7 @@ from .models import (
AddedResponse,
DeletedResponse,
CollectionSearchValuesResponse,
CollectionAllResponse,
OrgPublicCollections,
PublicOrgDetails,
CollAccessType,
@ -50,7 +50,12 @@ from .models import (
MIN_UPLOAD_PART_SIZE,
PublicCollOut,
)
-from .utils import dt_now, slug_from_name, get_duplicate_key_error_field, get_origin
from .utils import (
dt_now,
slug_from_name,
get_duplicate_key_error_field,
get_origin,
)
if TYPE_CHECKING:
from .orgs import OrgOps
@ -346,7 +351,7 @@ class CollectionOps:
result["resources"],
crawl_ids,
pages_optimized,
-) = await self.get_collection_crawl_resources(coll_id)
) = await self.get_collection_crawl_resources(coll_id, org)
initial_pages, _ = await self.page_ops.list_pages(
crawl_ids=crawl_ids,
@ -400,7 +405,9 @@ class CollectionOps:
if result.get("access") not in allowed_access:
raise HTTPException(status_code=404, detail="collection_not_found")
-result["resources"], _, _ = await self.get_collection_crawl_resources(coll_id)
result["resources"], _, _ = await self.get_collection_crawl_resources(
coll_id, org
)
thumbnail = result.get("thumbnail")
if thumbnail:
@ -554,31 +561,23 @@ class CollectionOps:
return collections, total
# pylint: disable=too-many-locals
async def get_collection_crawl_resources(
-self, coll_id: UUID
self, coll_id: Optional[UUID], org: Organization
) -> tuple[List[CrawlFileOut], List[str], bool]:
"""Return pre-signed resources for all collection crawl files."""
-# Ensure collection exists
-_ = await self.get_collection_raw(coll_id)
-resources = []
-pages_optimized = True
-crawls, _ = await self.crawl_ops.list_all_base_crawls(
-collection_id=coll_id,
-states=list(SUCCESSFUL_STATES),
-page_size=10_000,
-cls_type=CrawlOutWithResources,
)
match: dict[str, Any]
if coll_id:
crawl_ids = await self.get_collection_crawl_ids(coll_id)
match = {"_id": {"$in": crawl_ids}}
else:
crawl_ids = []
match = {"oid": org.id}
-for crawl in crawls:
-crawl_ids.append(crawl.id)
-if crawl.resources:
-resources.extend(crawl.resources)
-if crawl.version != 2:
-pages_optimized = False
resources, pages_optimized = await self.crawl_ops.get_presigned_files(
match, org
)
return resources, crawl_ids, pages_optimized
@ -1009,24 +1008,11 @@ def init_collections_api(
@app.get(
"/orgs/{oid}/collections/$all",
tags=["collections"],
-response_model=Dict[str, List[CrawlFileOut]],
response_model=CollectionAllResponse,
)
async def get_collection_all(org: Organization = Depends(org_viewer_dep)):
results = {}
-try:
-all_collections, _ = await colls.list_collections(org, page_size=10_000)
-for collection in all_collections:
-(
-results[collection.name],
-_,
-_,
-) = await colls.get_collection_crawl_resources(collection.id)
-except Exception as exc:
-# pylint: disable=raise-missing-from
-raise HTTPException(
-status_code=400, detail="Error Listing All Crawled Files: " + str(exc)
-)
results["resources"] = await colls.get_collection_crawl_resources(None, org)
return results
@app.get(


@ -1598,6 +1598,13 @@ class CollectionSearchValuesResponse(BaseModel):
names: List[str]
# ============================================================================
class CollectionAllResponse(BaseModel):
"""Response model for '$all' collection endpoint"""
resources: List[CrawlFileOut] = []
# ============================================================================
### ORGS ###


@ -191,6 +191,7 @@ class PageOps:
async def _add_pages_to_db(self, crawl_id: str, pages: List[Page], ordered=True):
"""Add batch of pages to db in one insert"""
try:
result = await self.pages.insert_many(
[
page.to_dict(
@ -200,6 +201,12 @@ class PageOps:
],
ordered=ordered,
)
except pymongo.errors.BulkWriteError as bwe:
for err in bwe.details.get("writeErrors", []):
# ignorable duplicate key errors
if err.get("code") != 11000:
raise
if not result.inserted_ids:
# pylint: disable=broad-exception-raised
raise Exception("No pages inserted")


@ -34,6 +34,7 @@ from aiobotocore.config import AioConfig
import aiobotocore.session
import requests
import pymongo
from types_aiobotocore_s3 import S3Client as AIOS3Client
from types_aiobotocore_s3.type_defs import CompletedPartTypeDef
@ -71,6 +72,7 @@ CHUNK_SIZE = 1024 * 256
# ============================================================================
# pylint: disable=broad-except,raise-missing-from,too-many-instance-attributes
# pylint: disable=too-many-public-methods
class StorageOps:
"""All storage handling, download/upload operations"""
@ -105,6 +107,7 @@ class StorageOps:
self.frontend_origin = f"{frontend_origin}.{default_namespace}"
self.local_minio_access_path = os.environ.get("LOCAL_MINIO_ACCESS_PATH")
self.presign_batch_size = int(os.environ.get("PRESIGN_BATCH_SIZE", 8))
with open(os.environ["STORAGES_JSON"], encoding="utf-8") as fh:
storage_list = json.loads(fh.read())
@ -146,6 +149,18 @@ class StorageOps:
async def init_index(self):
"""init index for storages"""
try:
await self.presigned_urls.create_index(
"signedAt", expireAfterSeconds=self.expire_at_duration_seconds
)
except pymongo.errors.OperationFailure:
# create_index() fails if expire_at_duration_seconds has changed since
# previous run
# if so, just delete this index (as this collection is temporary anyway)
# and recreate
print("Recreating presigned_urls index")
await self.presigned_urls.drop_indexes()
await self.presigned_urls.create_index(
"signedAt", expireAfterSeconds=self.expire_at_duration_seconds
)
@ -299,11 +314,10 @@ class StorageOps:
session = aiobotocore.session.get_session()
s3 = None
config = None
if for_presign and storage.access_endpoint_url != storage.endpoint_url:
s3 = {"addressing_style": storage.access_addressing_style}
config = AioConfig(signature_version="s3v4", s3=s3)
async with session.create_client(
@ -496,26 +510,15 @@ class StorageOps:
s3storage,
for_presign=True,
) as (client, bucket, key):
-orig_key = key
-key += crawlfile.filename
presigned_url = await client.generate_presigned_url(
"get_object",
-Params={"Bucket": bucket, "Key": key},
Params={"Bucket": bucket, "Key": key + crawlfile.filename},
ExpiresIn=PRESIGN_DURATION_SECONDS,
)
-if (
-s3storage.access_endpoint_url
-and s3storage.access_endpoint_url != s3storage.endpoint_url
-):
-virtual = s3storage.access_addressing_style == "virtual"
-parts = urlsplit(s3storage.endpoint_url)
-host_endpoint_url = (
-f"{parts.scheme}://{bucket}.{parts.netloc}/{orig_key}"
-if virtual
-else f"{parts.scheme}://{parts.netloc}/{bucket}/{orig_key}"
-)
host_endpoint_url = self.get_host_endpoint_url(s3storage, bucket, key)
if host_endpoint_url:
presigned_url = presigned_url.replace(
host_endpoint_url, s3storage.access_endpoint_url
)
@ -535,6 +538,83 @@ class StorageOps:
return presigned_url, now + self.signed_duration_delta
def get_host_endpoint_url(
self, s3storage: S3Storage, bucket: str, key: str
) -> Optional[str]:
"""compute host endpoint for given storage for replacement for access"""
if not s3storage.access_endpoint_url:
return None
if s3storage.access_endpoint_url == s3storage.endpoint_url:
return None
is_virtual = s3storage.access_addressing_style == "virtual"
parts = urlsplit(s3storage.endpoint_url)
host_endpoint_url = (
f"{parts.scheme}://{bucket}.{parts.netloc}/{key}"
if is_virtual
else f"{parts.scheme}://{parts.netloc}/{bucket}/{key}"
)
return host_endpoint_url
async def get_presigned_urls_bulk(
self, org: Organization, s3storage: S3Storage, filenames: list[str]
) -> tuple[list[str], datetime]:
"""generate pre-signed url for crawl file"""
urls = []
futures = []
num_batch = self.presign_batch_size
now = dt_now()
async with self.get_s3_client(
s3storage,
for_presign=True,
) as (client, bucket, key):
for filename in filenames:
futures.append(
client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": key + filename},
ExpiresIn=PRESIGN_DURATION_SECONDS,
)
)
host_endpoint_url = self.get_host_endpoint_url(s3storage, bucket, key)
for i in range(0, len(futures), num_batch):
batch = futures[i : i + num_batch]
results = await asyncio.gather(*batch)
presigned_obj = []
for presigned_url, filename in zip(results, filenames[i : i + num_batch]):
if host_endpoint_url:
presigned_url = presigned_url.replace(
host_endpoint_url, s3storage.access_endpoint_url
)
urls.append(presigned_url)
presigned_obj.append(
PresignedUrl(
id=filename, url=presigned_url, signedAt=now, oid=org.id
).to_dict()
)
try:
await self.presigned_urls.insert_many(presigned_obj, ordered=False)
except pymongo.errors.BulkWriteError as bwe:
for err in bwe.details.get("writeErrors", []):
# ignorable duplicate key errors
if err.get("code") != 11000:
raise
return urls, now + self.signed_duration_delta
async def delete_file_object(self, org: Organization, crawlfile: BaseFile) -> bool:
"""delete crawl file from storage."""
return await self._delete_file(org, crawlfile.filename, crawlfile.storage)


@ -94,6 +94,8 @@ data:
REPLICA_DELETION_DELAY_DAYS: "{{ .Values.replica_deletion_delay_days | default 0 }}"
PRESIGN_BATCH_SIZE: "{{ .Values.presign_batch_size | default 8 }}"
---
apiVersion: v1