Ensure replay.json returns correct origin for pagesQueryUrl (#2741)

- Use the Host and X-Forwarded-Proto headers from the API request
- Fixes #2740; a better fix for #2720 that avoids the need for a separate alias
Ilya Kreymer 2025-07-16 10:48:24 -07:00 committed by GitHub
parent 0402f14b5e
commit 3af94ca03d
7 changed files with 63 additions and 28 deletions

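The change derives the origin for pagesQueryUrl from the incoming API request (Host plus X-Forwarded-Proto) instead of relying on a separately configured alias. A rough sketch of the idea, assuming a hypothetical build_pages_query_url helper and a placeholder default origin (the actual wiring lives in the backend's crawl-output code, not shown in this diff):

from typing import Optional

DEFAULT_ORIGIN = "http://localhost:8000"  # placeholder fallback, for illustration only


def request_origin(headers: Optional[dict]) -> str:
    # Same idea as get_origin() below: trust the proxy-supplied scheme and host
    # when both are present, otherwise fall back to the configured default.
    if not headers:
        return DEFAULT_ORIGIN
    scheme = headers.get("x-forwarded-proto")
    host = headers.get("host")
    if not scheme or not host:
        return DEFAULT_ORIGIN
    return f"{scheme}://{host}"


def build_pages_query_url(headers: Optional[dict], oid: str, coll_id: str) -> str:
    # Hypothetical helper: compose the absolute URL the updated tests expect, e.g.
    # http://custom-domain.example.com/api/orgs/<oid>/collections/<coll_id>/pages
    return f"{request_origin(headers)}/api/orgs/{oid}/collections/{coll_id}/pages"

With a request forwarded through a reverse proxy that sets X-Forwarded-Proto: https and Host: custom-domain.example.com, this yields an https URL on the custom domain, which is what the new test at the end of this commit asserts.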
View File

@@ -18,7 +18,7 @@ import os
 import urllib.parse
 import asyncio
-from fastapi import HTTPException, Depends
+from fastapi import HTTPException, Depends, Request
 from fastapi.responses import StreamingResponse
 import pymongo
@@ -1057,27 +1057,33 @@ def init_base_crawls_api(app, user_dep, *args):
         tags=["all-crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_base_crawl(crawl_id: str, org: Organization = Depends(org_crawl_dep)):
-        return await ops.get_crawl_out(crawl_id, org)
+    async def get_base_crawl(
+        crawl_id: str, request: Request, org: Organization = Depends(org_crawl_dep)
+    ):
+        return await ops.get_crawl_out(crawl_id, org, headers=dict(request.headers))

     @app.get(
         "/orgs/all/all-crawls/{crawl_id}/replay.json",
         tags=["all-crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_base_crawl_admin(crawl_id, user: User = Depends(user_dep)):
+    async def get_base_crawl_admin(
+        crawl_id, request: Request, user: User = Depends(user_dep)
+    ):
         if not user.is_superuser:
             raise HTTPException(status_code=403, detail="Not Allowed")

-        return await ops.get_crawl_out(crawl_id, None)
+        return await ops.get_crawl_out(crawl_id, None, headers=dict(request.headers))

     @app.get(
         "/orgs/{oid}/all-crawls/{crawl_id}/replay.json",
         tags=["all-crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_crawl_out(crawl_id, org: Organization = Depends(org_viewer_dep)):
-        return await ops.get_crawl_out(crawl_id, org)
+    async def get_crawl_out(
+        crawl_id, request: Request, org: Organization = Depends(org_viewer_dep)
+    ):
+        return await ops.get_crawl_out(crawl_id, org, headers=dict(request.headers))

     @app.get(
         "/orgs/{oid}/all-crawls/{crawl_id}/download",

View File

@@ -12,7 +12,7 @@ from uuid import UUID
 from typing import Optional, List, Dict, Union, Any, Sequence, AsyncIterator
-from fastapi import Depends, HTTPException
+from fastapi import Depends, HTTPException, Request
 from fastapi.responses import StreamingResponse
 from redis import asyncio as exceptions
 from redis.asyncio.client import Redis
@@ -1345,19 +1345,27 @@ def init_crawls_api(crawl_manager: CrawlManager, app, user_dep, *args):
         tags=["crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_crawl_admin(crawl_id, user: User = Depends(user_dep)):
+    async def get_crawl_admin(
+        crawl_id, request: Request, user: User = Depends(user_dep)
+    ):
         if not user.is_superuser:
             raise HTTPException(status_code=403, detail="Not Allowed")

-        return await ops.get_crawl_out(crawl_id, None, "crawl")
+        return await ops.get_crawl_out(
+            crawl_id, None, "crawl", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/{oid}/crawls/{crawl_id}/replay.json",
         tags=["crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_crawl_out(crawl_id, org: Organization = Depends(org_viewer_dep)):
-        return await ops.get_crawl_out(crawl_id, org, "crawl")
+    async def get_crawl_out(
+        crawl_id, request: Request, org: Organization = Depends(org_viewer_dep)
+    ):
+        return await ops.get_crawl_out(
+            crawl_id, org, "crawl", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/{oid}/crawls/{crawl_id}/download", tags=["crawls"], response_model=bytes

View File

@@ -367,27 +367,39 @@ def init_uploads_api(app, user_dep, *args):
         tags=["uploads"],
         response_model=CrawlOut,
     )
-    async def get_upload(crawlid: str, org: Organization = Depends(org_crawl_dep)):
-        return await ops.get_crawl_out(crawlid, org, "upload")
+    async def get_upload(
+        crawlid: str, request: Request, org: Organization = Depends(org_crawl_dep)
+    ):
+        return await ops.get_crawl_out(
+            crawlid, org, "upload", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/all/uploads/{crawl_id}/replay.json",
         tags=["uploads"],
         response_model=CrawlOutWithResources,
     )
-    async def get_upload_replay_admin(crawl_id, user: User = Depends(user_dep)):
+    async def get_upload_replay_admin(
+        crawl_id, request: Request, user: User = Depends(user_dep)
+    ):
         if not user.is_superuser:
             raise HTTPException(status_code=403, detail="Not Allowed")

-        return await ops.get_crawl_out(crawl_id, None, "upload")
+        return await ops.get_crawl_out(
+            crawl_id, None, "upload", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/{oid}/uploads/{crawl_id}/replay.json",
         tags=["uploads"],
         response_model=CrawlOutWithResources,
     )
-    async def get_upload_replay(crawl_id, org: Organization = Depends(org_viewer_dep)):
-        return await ops.get_crawl_out(crawl_id, org, "upload")
+    async def get_upload_replay(
+        crawl_id, request: Request, org: Organization = Depends(org_viewer_dep)
+    ):
+        return await ops.get_crawl_out(
+            crawl_id, org, "upload", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/{oid}/uploads/{crawl_id}/download",

View File

@@ -181,8 +181,8 @@ def get_origin(headers) -> str:
     if not headers:
         return default_origin

-    scheme = headers.get("X-Forwarded-Proto")
-    host = headers.get("Host")
+    scheme = headers.get("x-forwarded-proto")
+    host = headers.get("host")

     if not scheme or not host:
         return default_origin

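Note on the lowercase lookups: the endpoints in the earlier files now pass dict(request.headers) into ops.get_crawl_out(), and ASGI servers hand header names to Starlette/FastAPI already lowercased, so the resulting plain dict is keyed by "host" and "x-forwarded-proto" rather than "Host" and "X-Forwarded-Proto". A small illustration (a sketch, not code from this commit):

from starlette.datastructures import Headers

# Starlette's Headers object is case-insensitive, but converting it to a plain
# dict keeps the raw (lowercased, per the ASGI spec) names as keys.
headers = Headers(
    raw=[(b"host", b"custom-domain.example.com"), (b"x-forwarded-proto", b"https")]
)

assert headers.get("Host") == "custom-domain.example.com"  # case-insensitive view
plain = dict(headers)
assert plain.get("Host") is None  # a plain dict is not
assert plain.get("host") == "custom-domain.example.com"
assert plain.get("x-forwarded-proto") == "https"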
View File

@@ -401,7 +401,7 @@ def test_get_collection(crawler_auth_headers, default_org_id):
 def test_get_collection_replay(crawler_auth_headers, default_org_id):
     r = requests.get(
         f"{API_PREFIX}/orgs/{default_org_id}/collections/{_coll_id}/replay.json",
-        headers=crawler_auth_headers,
+        headers={"host": "custom-domain.example.com", **crawler_auth_headers},
     )
     assert r.status_code == 200
     data = r.json()
@@ -421,8 +421,9 @@ def test_get_collection_replay(crawler_auth_headers, default_org_id):
     assert data["dateLatest"]
     assert data["defaultThumbnailName"]
     assert data["initialPages"]
-    assert data["pagesQueryUrl"].endswith(
-        f"/orgs/{default_org_id}/collections/{_coll_id}/pages"
+    assert (
+        data["pagesQueryUrl"]
+        == f"http://custom-domain.example.com/api/orgs/{default_org_id}/collections/{_coll_id}/pages"
     )
     assert data["downloadUrl"] is None
     assert "preloadResources" in data
@@ -455,12 +456,13 @@ def test_collection_public(crawler_auth_headers, default_org_id):
     r = requests.get(
         f"{API_PREFIX}/orgs/{default_org_id}/collections/{_coll_id}/public/replay.json",
-        headers=crawler_auth_headers,
+        headers={"host": "custom-domain.example.com", **crawler_auth_headers},
     )
     data = r.json()

     assert data["initialPages"]
-    assert data["pagesQueryUrl"].endswith(
-        f"/orgs/{default_org_id}/collections/{_coll_id}/public/pages"
+    assert (
+        data["pagesQueryUrl"]
+        == f"http://custom-domain.example.com/api/orgs/{default_org_id}/collections/{_coll_id}/public/pages"
     )
     assert data["downloadUrl"] is not None
     assert "preloadResources" in data

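The tests above pin the expected absolute pagesQueryUrl by injecting a known Host header next to the auth headers via dict unpacking (the auth entries win on any key collision, so the Authorization header is preserved). The crawl test in the next file extends the same pattern with X-Forwarded-Proto; a small sketch of a helper such a test might use (names here are illustrative, not from the commit):

from typing import Optional


def forwarded_headers(auth_headers: dict, host: str, proto: Optional[str] = None) -> dict:
    # Build request headers that simulate a reverse proxy in front of the API:
    # inject Host (and optionally X-Forwarded-Proto) while keeping the auth token.
    extra = {"host": host}
    if proto:
        extra["X-Forwarded-Proto"] = proto
    return {**extra, **auth_headers}


def expected_origin(host: str, proto: Optional[str] = None) -> str:
    # The backend should echo this origin back at the start of pagesQueryUrl.
    return f"{proto or 'http'}://{host}"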
View File

@@ -176,3 +176,12 @@ def test_stop_crawl_partial(
     assert data["stopping"] == True

     assert len(data["resources"]) == 1
+
+
+def test_crawl_with_hostname(default_org_id, crawler_auth_headers):
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawl_id}/replay.json",
+        headers={"X-Forwarded-Proto": "https", "host": "custom-domain.example.com", **crawler_auth_headers},
+    )
+    assert r.status_code == 200
+    assert r.json()["pagesQueryUrl"].startswith("https://custom-domain.example.com/")

View File

@@ -170,8 +170,6 @@ def test_delete_org_crawl_running(
         except:
             time.sleep(10)
             attempts += 1

     # Check that org was deleted