support usage counters per archive, per user -- handle crawl completion
This commit is contained in:
parent
170958be37
commit
4b08163ead
@ -3,7 +3,7 @@ Archive API handling
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
import datetime
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Optional, Dict
|
||||
|
||||
@ -57,6 +57,8 @@ class Archive(BaseMongoModel):
|
||||
|
||||
storage: S3Storage
|
||||
|
||||
usage: Dict[str, int] = {}
|
||||
|
||||
def is_owner(self, user):
|
||||
"""Check if user is owner"""
|
||||
return self._is_auth(user, UserRole.OWNER)
|
||||
@ -79,10 +81,13 @@ class Archive(BaseMongoModel):
|
||||
|
||||
def serialize_for_user(self, user: User):
|
||||
"""Serialize based on current user access"""
|
||||
exclude = {}
|
||||
exclude = set()
|
||||
if not self.is_owner(user):
|
||||
exclude = {"users", "storage"}
|
||||
|
||||
if not self.is_crawler(user):
|
||||
exclude.add("usage")
|
||||
|
||||
return self.dict(
|
||||
exclude_unset=True,
|
||||
exclude_defaults=True,
|
||||
@ -215,6 +220,15 @@ class ArchiveOps:
|
||||
await self.update(archive)
|
||||
return True
|
||||
|
||||
async def inc_usage(self, aid, amount):
|
||||
""" Increment usage counter by month for this archive """
|
||||
yymm = datetime.utcnow().strftime("%Y-%m")
|
||||
res = await self.archives.find_one_and_update(
|
||||
{"_id": aid}, {"$inc": {f"usage.{yymm}": amount}}
|
||||
)
|
||||
print(res)
|
||||
return res is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
def init_archives_api(app, mdb, users, email, user_dep: User):
|
||||
@ -264,7 +278,7 @@ def init_archives_api(app, mdb, users, email, user_dep: User):
|
||||
invite_code = uuid.uuid4().hex
|
||||
|
||||
invite_pending = InvitePending(
|
||||
aid=str(archive.id), created=datetime.datetime.utcnow(), role=invite.role
|
||||
aid=str(archive.id), created=datetime.utcnow(), role=invite.role
|
||||
)
|
||||
|
||||
other_user = await users.db.get_by_email(invite.email)
|
||||
|
@ -1,5 +1,7 @@
|
||||
""" Crawl API """
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
@ -9,6 +11,7 @@ from pydantic import BaseModel
|
||||
# ============================================================================
|
||||
class CrawlComplete(BaseModel):
|
||||
""" Store State of Completed Crawls """
|
||||
|
||||
id: str
|
||||
|
||||
user: str
|
||||
@ -19,24 +22,34 @@ class CrawlComplete(BaseModel):
|
||||
size: int
|
||||
hash: str
|
||||
|
||||
created: Optional[datetime]
|
||||
started: Optional[datetime]
|
||||
finished: Optional[datetime]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
def init_crawls_api(app, crawl_manager):
|
||||
def init_crawls_api(app, crawl_manager, users, archives):
|
||||
""" API for crawl management, including crawl done callback"""
|
||||
|
||||
async def on_handle_crawl_complete(msg: CrawlComplete):
|
||||
data = await crawl_manager.validate_crawl_data(msg)
|
||||
if data:
|
||||
print(msg)
|
||||
else:
|
||||
print("Not a valid crawl complete msg!")
|
||||
if not await crawl_manager.validate_crawl_complete(msg):
|
||||
print("Not a valid crawl complete msg!", flush=True)
|
||||
return
|
||||
|
||||
print(msg, flush=True)
|
||||
|
||||
dura = int((msg.finished - msg.started).total_seconds())
|
||||
|
||||
print(f"Duration: {dura}", flush=True)
|
||||
await users.inc_usage(msg.user, dura)
|
||||
await archives.inc_usage(msg.aid, dura)
|
||||
|
||||
@app.post("/crawls/done")
|
||||
async def webhook(msg: CrawlComplete):
|
||||
#background_tasks.add_task(on_handle_crawl_complete, msg)
|
||||
#asyncio.ensure_future(on_handle_crawl_complete(msg))
|
||||
await on_handle_crawl_complete(msg)
|
||||
# background_tasks.add_task(on_handle_crawl_complete, msg)
|
||||
# asyncio.ensure_future(on_handle_crawl_complete(msg))
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(on_handle_crawl_complete(msg))
|
||||
|
||||
# await on_handle_crawl_complete(msg)
|
||||
return {"message": "webhook received"}
|
||||
|
@ -1,13 +1,18 @@
|
||||
# pylint: skip-file
|
||||
|
||||
from archives import Archive
|
||||
from crawls import CrawlConfig
|
||||
import asyncio
|
||||
|
||||
|
||||
class DockerManager:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def test():
|
||||
print("test async", flush=True)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(test())
|
||||
print("starting")
|
||||
|
||||
async def add_crawl_config(
|
||||
self,
|
||||
userid: str,
|
||||
|
@ -30,19 +30,35 @@ class K8SManager:
|
||||
self.crawler_image = os.environ.get("CRAWLER_IMAGE")
|
||||
self.crawler_image_pull_policy = "IfNotPresent"
|
||||
|
||||
async def validate_crawl_data(self, crawlcomplete):
|
||||
""" Ensure the crawlcomplete data is valid (pod exists and user matches)
|
||||
Fill in additional details about the crawl """
|
||||
pod = await self.core_api.read_namespaced_pod(name=crawlcomplete.id, namespace=self.namespace)
|
||||
# loop = asyncio.get_running_loop()
|
||||
# loop.create_task(self.watch_job_done())
|
||||
|
||||
if not pod or pod.metadata.labels["btrix.user"] != crawlcomplete.user:
|
||||
async def validate_crawl_complete(self, crawlcomplete):
|
||||
"""Ensure the crawlcomplete data is valid (job exists and user matches)
|
||||
Fill in additional details about the crawl"""
|
||||
job = await self.batch_api.read_namespaced_job(
|
||||
name=crawlcomplete.id, namespace=self.namespace
|
||||
)
|
||||
|
||||
if not job or job.metadata.labels["btrix.user"] != crawlcomplete.user:
|
||||
return False
|
||||
|
||||
crawlcomplete.id = pod.metadata.labels["job-name"]
|
||||
crawlcomplete.created = pod.metadata.creation_timestamp
|
||||
crawlcomplete.aid = pod.metadata.labels["btrix.archive"]
|
||||
crawlcomplete.cid = pod.metadata.labels["btrix.crawlconfig"]
|
||||
crawlcomplete.finished = datetime.datetime.utcnow()
|
||||
# job.metadata.annotations = {
|
||||
# "crawl.size": str(crawlcomplete.size),
|
||||
# "crawl.filename": crawlcomplete.filename,
|
||||
# "crawl.hash": crawlcomplete.hash
|
||||
# }
|
||||
|
||||
# await self.batch_api.patch_namespaced_job(
|
||||
# name=crawlcomplete.id, namespace=self.namespace, body=job
|
||||
# )
|
||||
|
||||
crawlcomplete.started = job.status.start_time.replace(tzinfo=None)
|
||||
crawlcomplete.aid = job.metadata.labels["btrix.archive"]
|
||||
crawlcomplete.cid = job.metadata.labels["btrix.crawlconfig"]
|
||||
crawlcomplete.finished = datetime.datetime.utcnow().replace(
|
||||
microsecond=0, tzinfo=None
|
||||
)
|
||||
return True
|
||||
|
||||
async def add_crawl_config(
|
||||
@ -257,7 +273,9 @@ class K8SManager:
|
||||
{
|
||||
"name": "CRAWL_ID",
|
||||
"valueFrom": {
|
||||
"fieldRef": {"fieldPath": "metadata.name"}
|
||||
"fieldRef": {
|
||||
"fieldPath": "metadata.labels['job-name']"
|
||||
}
|
||||
},
|
||||
}
|
||||
],
|
||||
|
@ -72,7 +72,9 @@ class BrowsertrixAPI:
|
||||
self.crawl_manager,
|
||||
)
|
||||
|
||||
init_crawls_api(self.app, self.crawl_manager)
|
||||
init_crawls_api(
|
||||
self.app, self.crawl_manager, self.fastapi_users.db, self.archive_ops
|
||||
)
|
||||
|
||||
self.app.include_router(self.archive_ops.router)
|
||||
|
||||
@ -80,7 +82,6 @@ class BrowsertrixAPI:
|
||||
# async def root():
|
||||
# return {"message": "Hello World"}
|
||||
|
||||
|
||||
# pylint: disable=no-self-use, unused-argument
|
||||
async def on_after_register(self, user: UserDB, request: Request):
|
||||
"""callback after registeration"""
|
||||
@ -123,9 +124,6 @@ class BrowsertrixAPI:
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# app = BrowsertrixAPI().app
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
"""init on startup"""
|
||||
|
@ -11,7 +11,7 @@ from typing import Dict, Optional
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, UUID4
|
||||
|
||||
from fastapi_users import FastAPIUsers, models
|
||||
from fastapi_users.authentication import JWTAuthentication
|
||||
@ -44,6 +44,8 @@ class User(models.BaseUser):
|
||||
Base User Model
|
||||
"""
|
||||
|
||||
usage: Dict[str, int] = {}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
class UserCreate(models.BaseUserCreate):
|
||||
@ -69,6 +71,19 @@ class UserDB(User, models.BaseUserDB):
|
||||
"""
|
||||
|
||||
invites: Dict[str, InvitePending] = {}
|
||||
usage: Dict[str, int] = {}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
class UserDBOps(MongoDBUserDatabase):
|
||||
""" User DB Operations wrapper """
|
||||
|
||||
async def inc_usage(self, userid, amount):
|
||||
""" Increment usage counter by month for this user """
|
||||
yymm = datetime.utcnow().strftime("%Y-%m")
|
||||
await self.collection.find_one_and_update(
|
||||
{"id": UUID4(userid)}, {"$inc": {f"usage.{yymm}": amount}}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@ -85,7 +100,7 @@ def init_users_api(
|
||||
|
||||
user_collection = mdb.get_collection("users")
|
||||
|
||||
user_db = MongoDBUserDatabase(UserDB, user_collection)
|
||||
user_db = UserDBOps(UserDB, user_collection)
|
||||
|
||||
jwt_authentication = JWTAuthentication(
|
||||
secret=PASSWORD_SECRET, lifetime_seconds=3600, tokenUrl="/auth/jwt/login"
|
||||
@ -99,6 +114,7 @@ def init_users_api(
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(jwt_authentication),
|
||||
prefix="/auth/jwt",
|
||||
|
Loading…
Reference in New Issue
Block a user