Optimize presigning for replay.json (#2516)

Fixes #2515.

This PR introduces significantly optimized logic for presigning URLs
for crawls and collections:
- For collections, the files needed from all crawls are looked up, and
the 'presigned_urls' table is merged in one pass, yielding a
unified iterator of files together with the presigned URLs for those files.
- For crawls, the presigned URLs are also looked up once, and the same
iterator is used for a single crawl with a passed-in list of CrawlFiles.
- URLs that are already signed are added to the return list.
- For any remaining URLs to be signed, a bulk presigning function is
added, which shares a single HTTP client and signs 8 files in parallel
(customizable via the Helm chart, though this may not be needed). This function
is used to call the presigning API in parallel; a simplified sketch of the
batching pattern is shown below.
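
For reference, here is a minimal, self-contained sketch of that batching pattern. It is illustrative only: the real implementation is `get_presigned_urls_bulk` in the storage diff further down, and the helper name `presign_in_batches`, its credential/region parameters, and the example expiry value are assumptions made for the sketch, not code from this PR.

```python
import asyncio
from typing import List

import aiobotocore.session

# illustrative expiry; the PR uses its own PRESIGN_DURATION_SECONDS constant
PRESIGN_DURATION_SECONDS = 3600


async def presign_in_batches(
    endpoint_url: str,
    access_key: str,
    secret_key: str,
    bucket: str,
    keys: List[str],
    batch_size: int = 8,
    region: str = "us-east-1",
) -> List[str]:
    """Presign GET URLs for keys, batch_size at a time, over one shared client."""
    session = aiobotocore.session.get_session()
    urls: List[str] = []

    # a single client (and its connection pool) is shared by all presign calls
    async with session.create_client(
        "s3",
        endpoint_url=endpoint_url,
        region_name=region,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    ) as client:
        # build all presign coroutines up front...
        futures = [
            client.generate_presigned_url(
                "get_object",
                Params={"Bucket": bucket, "Key": key},
                ExpiresIn=PRESIGN_DURATION_SECONDS,
            )
            for key in keys
        ]
        # ...then await them in fixed-size batches
        for i in range(0, len(futures), batch_size):
            urls.extend(await asyncio.gather(*futures[i : i + batch_size]))

    return urls
```

In the PR itself, the batch size comes from the `PRESIGN_BATCH_SIZE` setting (default 8, exposed as `presign_batch_size` in the Helm chart), and newly signed URLs are also cached in the `presigned_urls` Mongo collection with a TTL index on `signedAt`.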
Ilya Kreymer 2025-05-20 12:09:35 -07:00 committed by GitHub
parent f1fd11c031
commit c134b576ae
6 changed files with 283 additions and 88 deletions


@@ -1,7 +1,18 @@
 """base crawl type"""
 
 from datetime import datetime
-from typing import Optional, List, Union, Dict, Any, Type, TYPE_CHECKING, cast, Tuple
+from typing import (
+    Optional,
+    List,
+    Union,
+    Dict,
+    Any,
+    Type,
+    TYPE_CHECKING,
+    cast,
+    Tuple,
+    AsyncIterable,
+)
 from uuid import UUID
 import os
 import urllib.parse
@@ -76,6 +87,7 @@ class BaseCrawlOps:
         background_job_ops: BackgroundJobOps,
     ):
         self.crawls = mdb["crawls"]
+        self.presigned_urls = mdb["presigned_urls"]
         self.crawl_configs = crawl_configs
         self.user_manager = users
         self.orgs = orgs
@@ -464,29 +476,130 @@ class BaseCrawlOps:
     ) -> List[CrawlFileOut]:
         """Regenerate presigned URLs for files as necessary"""
         if not files:
+            print("no files")
             return []
 
         out_files = []
-        for file_ in files:
-            presigned_url, expire_at = await self.storage_ops.get_presigned_url(
-                org, file_, force_update=force_update
-            )
-            out_files.append(
-                CrawlFileOut(
-                    name=os.path.basename(file_.filename),
-                    path=presigned_url or "",
-                    hash=file_.hash,
-                    size=file_.size,
-                    crawlId=crawl_id,
-                    numReplicas=len(file_.replicas) if file_.replicas else 0,
-                    expireAt=date_to_str(expire_at),
-                )
-            )
-        return out_files
+
+        cursor = self.presigned_urls.find(
+            {"_id": {"$in": [file.filename for file in files]}}
+        )
+
+        presigned = await cursor.to_list(10000)
+
+        files_dict = [file.dict() for file in files]
+
+        # need an async generator to call bulk_presigned_files
+        async def async_gen():
+            yield {"presigned": presigned, "files": files_dict, "_id": crawl_id}
+
+        out_files, _ = await self.bulk_presigned_files(async_gen(), org, force_update)
+
+        return out_files
+
+    async def get_presigned_files(
+        self, match: dict[str, Any], org: Organization
+    ) -> tuple[list[CrawlFileOut], bool]:
+        """return presigned crawl files queried as batch, merging presigns with files in one pass"""
+        cursor = self.crawls.aggregate(
+            [
+                {"$match": match},
+                {"$project": {"files": "$files", "version": 1}},
+                {
+                    "$lookup": {
+                        "from": "presigned_urls",
+                        "localField": "files.filename",
+                        "foreignField": "_id",
+                        "as": "presigned",
+                    }
+                },
+            ]
+        )
+
+        return await self.bulk_presigned_files(cursor, org)
+
+    async def bulk_presigned_files(
+        self,
+        cursor: AsyncIterable[dict[str, Any]],
+        org: Organization,
+        force_update=False,
+    ) -> tuple[list[CrawlFileOut], bool]:
+        """process presigned files in batches"""
+        resources = []
+        pages_optimized = False
+
+        sign_files = []
+
+        async for result in cursor:
+            pages_optimized = result.get("version") == 2
+
+            mapping = {}
+            # create mapping of filename -> file data
+            for file in result["files"]:
+                file["crawl_id"] = result["_id"]
+                mapping[file["filename"]] = file
+
+            if not force_update:
+                # add already presigned resources
+                for presigned in result["presigned"]:
+                    file = mapping.get(presigned["_id"])
+                    if file:
+                        file["signedAt"] = presigned["signedAt"]
+                        file["path"] = presigned["url"]
+                        resources.append(
+                            CrawlFileOut(
+                                name=os.path.basename(file["filename"]),
+                                path=presigned["url"],
+                                hash=file["hash"],
+                                size=file["size"],
+                                crawlId=file["crawl_id"],
+                                numReplicas=len(file.get("replicas") or []),
+                                expireAt=date_to_str(
+                                    presigned["signedAt"]
+                                    + self.storage_ops.signed_duration_delta
+                                ),
+                            )
+                        )
+
+                        del mapping[presigned["_id"]]
+
+            sign_files.extend(list(mapping.values()))
+
+        by_storage: dict[str, dict] = {}
+        for file in sign_files:
+            storage_ref = StorageRef(**file.get("storage"))
+            sid = str(storage_ref)
+
+            storage_group = by_storage.get(sid)
+            if not storage_group:
+                storage_group = {"ref": storage_ref, "names": [], "files": []}
+                by_storage[sid] = storage_group
+
+            storage_group["names"].append(file["filename"])
+            storage_group["files"].append(file)
+
+        for storage_group in by_storage.values():
+            s3storage = self.storage_ops.get_org_storage_by_ref(
+                org, storage_group["ref"]
+            )
+
+            signed_urls, expire_at = await self.storage_ops.get_presigned_urls_bulk(
+                org, s3storage, storage_group["names"]
+            )
+
+            for url, file in zip(signed_urls, storage_group["files"]):
+                resources.append(
+                    CrawlFileOut(
+                        name=os.path.basename(file["filename"]),
+                        path=url,
+                        hash=file["hash"],
+                        size=file["size"],
+                        crawlId=file["crawl_id"],
+                        numReplicas=len(file.get("replicas") or []),
+                        expireAt=date_to_str(expire_at),
+                    )
+                )
+
+        return resources, pages_optimized
 
     async def add_to_collection(
         self, crawl_ids: List[str], collection_id: UUID, org: Organization


@@ -28,7 +28,6 @@ from .models import (
     UpdateColl,
     AddRemoveCrawlList,
     BaseCrawl,
-    CrawlOutWithResources,
     CrawlFileOut,
     Organization,
     PaginatedCollOutResponse,
@@ -40,6 +39,7 @@ from .models import (
     AddedResponse,
     DeletedResponse,
     CollectionSearchValuesResponse,
+    CollectionAllResponse,
     OrgPublicCollections,
     PublicOrgDetails,
     CollAccessType,
@@ -50,7 +50,12 @@ from .models import (
     MIN_UPLOAD_PART_SIZE,
     PublicCollOut,
 )
-from .utils import dt_now, slug_from_name, get_duplicate_key_error_field, get_origin
+from .utils import (
+    dt_now,
+    slug_from_name,
+    get_duplicate_key_error_field,
+    get_origin,
+)
 
 if TYPE_CHECKING:
     from .orgs import OrgOps
@@ -346,7 +351,7 @@ class CollectionOps:
                 result["resources"],
                 crawl_ids,
                 pages_optimized,
-            ) = await self.get_collection_crawl_resources(coll_id)
+            ) = await self.get_collection_crawl_resources(coll_id, org)
 
             initial_pages, _ = await self.page_ops.list_pages(
                 crawl_ids=crawl_ids,
@@ -400,7 +405,9 @@ class CollectionOps:
         if result.get("access") not in allowed_access:
             raise HTTPException(status_code=404, detail="collection_not_found")
 
-        result["resources"], _, _ = await self.get_collection_crawl_resources(coll_id)
+        result["resources"], _, _ = await self.get_collection_crawl_resources(
+            coll_id, org
+        )
 
         thumbnail = result.get("thumbnail")
         if thumbnail:
@@ -554,31 +561,23 @@ class CollectionOps:
         return collections, total
 
-    # pylint: disable=too-many-locals
     async def get_collection_crawl_resources(
-        self, coll_id: UUID
+        self, coll_id: Optional[UUID], org: Organization
     ) -> tuple[List[CrawlFileOut], List[str], bool]:
         """Return pre-signed resources for all collection crawl files."""
-        # Ensure collection exists
-        _ = await self.get_collection_raw(coll_id)
-
-        resources = []
-        pages_optimized = True
-
-        crawls, _ = await self.crawl_ops.list_all_base_crawls(
-            collection_id=coll_id,
-            states=list(SUCCESSFUL_STATES),
-            page_size=10_000,
-            cls_type=CrawlOutWithResources,
-        )
-
-        crawl_ids = []
-
-        for crawl in crawls:
-            crawl_ids.append(crawl.id)
-            if crawl.resources:
-                resources.extend(crawl.resources)
-            if crawl.version != 2:
-                pages_optimized = False
+        match: dict[str, Any]
+
+        if coll_id:
+            crawl_ids = await self.get_collection_crawl_ids(coll_id)
+            match = {"_id": {"$in": crawl_ids}}
+        else:
+            crawl_ids = []
+            match = {"oid": org.id}
+
+        resources, pages_optimized = await self.crawl_ops.get_presigned_files(
+            match, org
+        )
 
         return resources, crawl_ids, pages_optimized
@@ -1009,24 +1008,11 @@ def init_collections_api(
     @app.get(
         "/orgs/{oid}/collections/$all",
         tags=["collections"],
-        response_model=Dict[str, List[CrawlFileOut]],
+        response_model=CollectionAllResponse,
     )
     async def get_collection_all(org: Organization = Depends(org_viewer_dep)):
         results = {}
-        try:
-            all_collections, _ = await colls.list_collections(org, page_size=10_000)
-            for collection in all_collections:
-                (
-                    results[collection.name],
-                    _,
-                    _,
-                ) = await colls.get_collection_crawl_resources(collection.id)
-        except Exception as exc:
-            # pylint: disable=raise-missing-from
-            raise HTTPException(
-                status_code=400, detail="Error Listing All Crawled Files: " + str(exc)
-            )
+        results["resources"] = await colls.get_collection_crawl_resources(None, org)
 
         return results
 
     @app.get(


@@ -1598,6 +1598,13 @@ class CollectionSearchValuesResponse(BaseModel):
     names: List[str]
 
 
+# ============================================================================
+class CollectionAllResponse(BaseModel):
+    """Response model for '$all' collection endpoint"""
+
+    resources: List[CrawlFileOut] = []
+
+
 # ============================================================================
 ### ORGS ###


@@ -191,6 +191,7 @@ class PageOps:
 
     async def _add_pages_to_db(self, crawl_id: str, pages: List[Page], ordered=True):
         """Add batch of pages to db in one insert"""
+        try:
             result = await self.pages.insert_many(
                 [
                     page.to_dict(
@@ -200,6 +201,12 @@ class PageOps:
                 ],
                 ordered=ordered,
             )
+        except pymongo.errors.BulkWriteError as bwe:
+            for err in bwe.details.get("writeErrors", []):
+                # ignorable duplicate key errors
+                if err.get("code") != 11000:
+                    raise
 
         if not result.inserted_ids:
             # pylint: disable=broad-exception-raised
             raise Exception("No pages inserted")


@@ -34,6 +34,7 @@ from aiobotocore.config import AioConfig
 import aiobotocore.session
 import requests
+import pymongo
 
 from types_aiobotocore_s3 import S3Client as AIOS3Client
 from types_aiobotocore_s3.type_defs import CompletedPartTypeDef
@@ -71,6 +72,7 @@ CHUNK_SIZE = 1024 * 256
 
 # ============================================================================
 # pylint: disable=broad-except,raise-missing-from,too-many-instance-attributes
+# pylint: disable=too-many-public-methods
 class StorageOps:
     """All storage handling, download/upload operations"""
@@ -105,6 +107,7 @@ class StorageOps:
         self.frontend_origin = f"{frontend_origin}.{default_namespace}"
 
         self.local_minio_access_path = os.environ.get("LOCAL_MINIO_ACCESS_PATH")
+        self.presign_batch_size = int(os.environ.get("PRESIGN_BATCH_SIZE", 8))
 
         with open(os.environ["STORAGES_JSON"], encoding="utf-8") as fh:
             storage_list = json.loads(fh.read())
@@ -146,6 +149,18 @@ class StorageOps:
 
     async def init_index(self):
         """init index for storages"""
+        try:
+            await self.presigned_urls.create_index(
+                "signedAt", expireAfterSeconds=self.expire_at_duration_seconds
+            )
+        except pymongo.errors.OperationFailure:
+            # create_index() fails if expire_at_duration_seconds has changed since
+            # previous run
+            # if so, just delete this index (as this collection is temporary anyway)
+            # and recreate
+            print("Recreating presigned_urls index")
+            await self.presigned_urls.drop_indexes()
             await self.presigned_urls.create_index(
                 "signedAt", expireAfterSeconds=self.expire_at_duration_seconds
             )
@@ -299,11 +314,10 @@ class StorageOps:
 
         session = aiobotocore.session.get_session()
 
-        s3 = None
+        config = None
         if for_presign and storage.access_endpoint_url != storage.endpoint_url:
             s3 = {"addressing_style": storage.access_addressing_style}
-
-        config = AioConfig(signature_version="s3v4", s3=s3)
+            config = AioConfig(signature_version="s3v4", s3=s3)
 
         async with session.create_client(
@@ -496,26 +510,15 @@ class StorageOps:
             s3storage,
             for_presign=True,
         ) as (client, bucket, key):
-            orig_key = key
-            key += crawlfile.filename
-
             presigned_url = await client.generate_presigned_url(
                 "get_object",
-                Params={"Bucket": bucket, "Key": key},
+                Params={"Bucket": bucket, "Key": key + crawlfile.filename},
                 ExpiresIn=PRESIGN_DURATION_SECONDS,
             )
 
-            if (
-                s3storage.access_endpoint_url
-                and s3storage.access_endpoint_url != s3storage.endpoint_url
-            ):
-                virtual = s3storage.access_addressing_style == "virtual"
-                parts = urlsplit(s3storage.endpoint_url)
-                host_endpoint_url = (
-                    f"{parts.scheme}://{bucket}.{parts.netloc}/{orig_key}"
-                    if virtual
-                    else f"{parts.scheme}://{parts.netloc}/{bucket}/{orig_key}"
-                )
+            host_endpoint_url = self.get_host_endpoint_url(s3storage, bucket, key)
+
+            if host_endpoint_url:
                 presigned_url = presigned_url.replace(
                     host_endpoint_url, s3storage.access_endpoint_url
                 )
@@ -535,6 +538,83 @@ class StorageOps:
 
         return presigned_url, now + self.signed_duration_delta
 
+    def get_host_endpoint_url(
+        self, s3storage: S3Storage, bucket: str, key: str
+    ) -> Optional[str]:
+        """compute host endpoint for given storage for replacement for access"""
+        if not s3storage.access_endpoint_url:
+            return None
+
+        if s3storage.access_endpoint_url == s3storage.endpoint_url:
+            return None
+
+        is_virtual = s3storage.access_addressing_style == "virtual"
+        parts = urlsplit(s3storage.endpoint_url)
+        host_endpoint_url = (
+            f"{parts.scheme}://{bucket}.{parts.netloc}/{key}"
+            if is_virtual
+            else f"{parts.scheme}://{parts.netloc}/{bucket}/{key}"
+        )
+        return host_endpoint_url
+
+    async def get_presigned_urls_bulk(
+        self, org: Organization, s3storage: S3Storage, filenames: list[str]
+    ) -> tuple[list[str], datetime]:
+        """generate pre-signed url for crawl file"""
+        urls = []
+
+        futures = []
+        num_batch = self.presign_batch_size
+
+        now = dt_now()
+
+        async with self.get_s3_client(
+            s3storage,
+            for_presign=True,
+        ) as (client, bucket, key):
+            for filename in filenames:
+                futures.append(
+                    client.generate_presigned_url(
+                        "get_object",
+                        Params={"Bucket": bucket, "Key": key + filename},
+                        ExpiresIn=PRESIGN_DURATION_SECONDS,
+                    )
+                )
+
+            host_endpoint_url = self.get_host_endpoint_url(s3storage, bucket, key)
+
+            for i in range(0, len(futures), num_batch):
+                batch = futures[i : i + num_batch]
+                results = await asyncio.gather(*batch)
+
+                presigned_obj = []
+
+                for presigned_url, filename in zip(results, filenames[i : i + num_batch]):
+                    if host_endpoint_url:
+                        presigned_url = presigned_url.replace(
+                            host_endpoint_url, s3storage.access_endpoint_url
+                        )
+
+                    urls.append(presigned_url)
+
+                    presigned_obj.append(
+                        PresignedUrl(
+                            id=filename, url=presigned_url, signedAt=now, oid=org.id
+                        ).to_dict()
+                    )
+
+                try:
+                    await self.presigned_urls.insert_many(presigned_obj, ordered=False)
+                except pymongo.errors.BulkWriteError as bwe:
+                    for err in bwe.details.get("writeErrors", []):
+                        # ignorable duplicate key errors
+                        if err.get("code") != 11000:
+                            raise
+
+        return urls, now + self.signed_duration_delta
+
     async def delete_file_object(self, org: Organization, crawlfile: BaseFile) -> bool:
         """delete crawl file from storage."""
         return await self._delete_file(org, crawlfile.filename, crawlfile.storage)


@@ -94,6 +94,8 @@ data:
   REPLICA_DELETION_DELAY_DAYS: "{{ .Values.replica_deletion_delay_days | default 0 }}"
 
+  PRESIGN_BATCH_SIZE: "{{ .Values.presign_batch_size | default 8 }}"
+
 ---
 apiVersion: v1