Ensure replay.json returns correct origin for pagesQueryUrl (#2741)
- Use the Host + X-Forwarded-Proto headers from the API request
- Fixes #2740; a better fix for #2720 that avoids the need for a separate alias
This commit is contained in:
parent 0402f14b5e
commit 3af94ca03d
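Every replay.json endpoint touched below follows the same pattern: the handler gains a fastapi.Request parameter (which FastAPI injects automatically, so the public API surface is unchanged) and forwards dict(request.headers) to ops.get_crawl_out(), which uses the headers to derive the origin for pagesQueryUrl. A minimal sketch of the pattern, with a made-up route path and replay_info() helper standing in for the real Browsertrix code:

# Minimal sketch of the pattern this commit applies; the route path and the
# replay_info() helper are illustrative stand-ins, not the real API.
from typing import Optional

from fastapi import FastAPI, Request

app = FastAPI()


async def replay_info(crawl_id: str, headers: Optional[dict] = None) -> dict:
    # Stand-in for ops.get_crawl_out(): build pagesQueryUrl from the origin
    # the client actually used, falling back to a default when unknown.
    origin = "http://localhost"
    if headers and headers.get("host"):
        origin = f"{headers.get('x-forwarded-proto', 'http')}://{headers['host']}"
    return {"id": crawl_id, "pagesQueryUrl": f"{origin}/api/pages"}


@app.get("/crawls/{crawl_id}/replay.json")
async def get_crawl_replay(crawl_id: str, request: Request):
    # dict(request.headers) yields lower-cased header names ("host",
    # "x-forwarded-proto"), which is why the get_origin() lookups further
    # down in this commit switch to lower-case keys.
    return await replay_info(crawl_id, headers=dict(request.headers))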
@@ -18,7 +18,7 @@ import os
 import urllib.parse
 import asyncio
-from fastapi import HTTPException, Depends
+from fastapi import HTTPException, Depends, Request
 from fastapi.responses import StreamingResponse
 import pymongo
@@ -1057,27 +1057,33 @@ def init_base_crawls_api(app, user_dep, *args):
         tags=["all-crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_base_crawl(crawl_id: str, org: Organization = Depends(org_crawl_dep)):
-        return await ops.get_crawl_out(crawl_id, org)
+    async def get_base_crawl(
+        crawl_id: str, request: Request, org: Organization = Depends(org_crawl_dep)
+    ):
+        return await ops.get_crawl_out(crawl_id, org, headers=dict(request.headers))

     @app.get(
         "/orgs/all/all-crawls/{crawl_id}/replay.json",
         tags=["all-crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_base_crawl_admin(crawl_id, user: User = Depends(user_dep)):
+    async def get_base_crawl_admin(
+        crawl_id, request: Request, user: User = Depends(user_dep)
+    ):
         if not user.is_superuser:
             raise HTTPException(status_code=403, detail="Not Allowed")

-        return await ops.get_crawl_out(crawl_id, None)
+        return await ops.get_crawl_out(crawl_id, None, headers=dict(request.headers))

     @app.get(
         "/orgs/{oid}/all-crawls/{crawl_id}/replay.json",
         tags=["all-crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_crawl_out(crawl_id, org: Organization = Depends(org_viewer_dep)):
-        return await ops.get_crawl_out(crawl_id, org)
+    async def get_crawl_out(
+        crawl_id, request: Request, org: Organization = Depends(org_viewer_dep)
+    ):
+        return await ops.get_crawl_out(crawl_id, org, headers=dict(request.headers))

     @app.get(
         "/orgs/{oid}/all-crawls/{crawl_id}/download",
@@ -12,7 +12,7 @@ from uuid import UUID

 from typing import Optional, List, Dict, Union, Any, Sequence, AsyncIterator

-from fastapi import Depends, HTTPException
+from fastapi import Depends, HTTPException, Request
 from fastapi.responses import StreamingResponse
 from redis import asyncio as exceptions
 from redis.asyncio.client import Redis
@@ -1345,19 +1345,27 @@ def init_crawls_api(crawl_manager: CrawlManager, app, user_dep, *args):
         tags=["crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_crawl_admin(crawl_id, user: User = Depends(user_dep)):
+    async def get_crawl_admin(
+        crawl_id, request: Request, user: User = Depends(user_dep)
+    ):
         if not user.is_superuser:
             raise HTTPException(status_code=403, detail="Not Allowed")

-        return await ops.get_crawl_out(crawl_id, None, "crawl")
+        return await ops.get_crawl_out(
+            crawl_id, None, "crawl", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/{oid}/crawls/{crawl_id}/replay.json",
         tags=["crawls"],
         response_model=CrawlOutWithResources,
     )
-    async def get_crawl_out(crawl_id, org: Organization = Depends(org_viewer_dep)):
-        return await ops.get_crawl_out(crawl_id, org, "crawl")
+    async def get_crawl_out(
+        crawl_id, request: Request, org: Organization = Depends(org_viewer_dep)
+    ):
+        return await ops.get_crawl_out(
+            crawl_id, org, "crawl", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/{oid}/crawls/{crawl_id}/download", tags=["crawls"], response_model=bytes
@@ -367,27 +367,39 @@ def init_uploads_api(app, user_dep, *args):
         tags=["uploads"],
         response_model=CrawlOut,
     )
-    async def get_upload(crawlid: str, org: Organization = Depends(org_crawl_dep)):
-        return await ops.get_crawl_out(crawlid, org, "upload")
+    async def get_upload(
+        crawlid: str, request: Request, org: Organization = Depends(org_crawl_dep)
+    ):
+        return await ops.get_crawl_out(
+            crawlid, org, "upload", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/all/uploads/{crawl_id}/replay.json",
         tags=["uploads"],
         response_model=CrawlOutWithResources,
     )
-    async def get_upload_replay_admin(crawl_id, user: User = Depends(user_dep)):
+    async def get_upload_replay_admin(
+        crawl_id, request: Request, user: User = Depends(user_dep)
+    ):
         if not user.is_superuser:
             raise HTTPException(status_code=403, detail="Not Allowed")

-        return await ops.get_crawl_out(crawl_id, None, "upload")
+        return await ops.get_crawl_out(
+            crawl_id, None, "upload", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/{oid}/uploads/{crawl_id}/replay.json",
         tags=["uploads"],
         response_model=CrawlOutWithResources,
     )
-    async def get_upload_replay(crawl_id, org: Organization = Depends(org_viewer_dep)):
-        return await ops.get_crawl_out(crawl_id, org, "upload")
+    async def get_upload_replay(
+        crawl_id, request: Request, org: Organization = Depends(org_viewer_dep)
+    ):
+        return await ops.get_crawl_out(
+            crawl_id, org, "upload", headers=dict(request.headers)
+        )

     @app.get(
         "/orgs/{oid}/uploads/{crawl_id}/download",
@@ -181,8 +181,8 @@ def get_origin(headers) -> str:
     if not headers:
         return default_origin

-    scheme = headers.get("X-Forwarded-Proto")
-    host = headers.get("Host")
+    scheme = headers.get("x-forwarded-proto")
+    host = headers.get("host")
     if not scheme or not host:
         return default_origin
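Only the header lookups of get_origin() appear in the hunk above. For orientation, a hedged reconstruction of how the helper plausibly composes the origin from those two headers; the fallback value below is a made-up example, not the real default:

# Hedged reconstruction for illustration; only the header lookups above are
# taken from the actual source, and the fallback value is invented.
default_origin = "http://localhost:9871"


def get_origin(headers) -> str:
    if not headers:
        return default_origin

    # Header names arrive lower-cased via dict(request.headers), hence these keys
    scheme = headers.get("x-forwarded-proto")
    host = headers.get("host")
    if not scheme or not host:
        return default_origin

    return f"{scheme}://{host}"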
@@ -401,7 +401,7 @@ def test_get_collection(crawler_auth_headers, default_org_id):
 def test_get_collection_replay(crawler_auth_headers, default_org_id):
     r = requests.get(
         f"{API_PREFIX}/orgs/{default_org_id}/collections/{_coll_id}/replay.json",
-        headers=crawler_auth_headers,
+        headers={"host": "custom-domain.example.com", **crawler_auth_headers},
     )
     assert r.status_code == 200
     data = r.json()
@@ -421,8 +421,9 @@ def test_get_collection_replay(crawler_auth_headers, default_org_id):
     assert data["dateLatest"]
     assert data["defaultThumbnailName"]
     assert data["initialPages"]
-    assert data["pagesQueryUrl"].endswith(
-        f"/orgs/{default_org_id}/collections/{_coll_id}/pages"
+    assert (
+        data["pagesQueryUrl"]
+        == f"http://custom-domain.example.com/api/orgs/{default_org_id}/collections/{_coll_id}/pages"
     )
     assert data["downloadUrl"] is None
     assert "preloadResources" in data
@@ -455,12 +456,13 @@ def test_collection_public(crawler_auth_headers, default_org_id):

     r = requests.get(
         f"{API_PREFIX}/orgs/{default_org_id}/collections/{_coll_id}/public/replay.json",
-        headers=crawler_auth_headers,
+        headers={"host": "custom-domain.example.com", **crawler_auth_headers},
     )
     data = r.json()
     assert data["initialPages"]
-    assert data["pagesQueryUrl"].endswith(
-        f"/orgs/{default_org_id}/collections/{_coll_id}/public/pages"
+    assert (
+        data["pagesQueryUrl"]
+        == f"http://custom-domain.example.com/api/orgs/{default_org_id}/collections/{_coll_id}/public/pages"
     )
     assert data["downloadUrl"] is not None
     assert "preloadResources" in data
@@ -176,3 +176,12 @@ def test_stop_crawl_partial(
     assert data["stopping"] == True

     assert len(data["resources"]) == 1


+def test_crawl_with_hostname(default_org_id, crawler_auth_headers):
+    r = requests.get(
+        f"{API_PREFIX}/orgs/{default_org_id}/crawls/{crawl_id}/replay.json",
+        headers={"X-Forwarded-Proto": "https", "host": "custom-domain.example.com", **crawler_auth_headers},
+    )
+    assert r.status_code == 200
+    assert r.json()["pagesQueryUrl"].startswith("https://custom-domain.example.com/")
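The new test sends the same two headers a reverse proxy would add and pins pagesQueryUrl to the resulting origin; roughly, the expected mapping is as follows (illustrative only, mirroring the values used in the test above):

# Illustrative only: the forwarded scheme and host determine the origin
# that pagesQueryUrl must start with.
forwarded = {"x-forwarded-proto": "https", "host": "custom-domain.example.com"}
expected_origin = f"{forwarded['x-forwarded-proto']}://{forwarded['host']}"
assert expected_origin == "https://custom-domain.example.com"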
@@ -170,8 +170,6 @@ def test_delete_org_crawl_running(
         except:
             time.sleep(10)

-
-
         attempts += 1

     # Check that org was deleted