Add API endpoint to import subscription for existing org (#1930)

Fixes #1926 

- adds /subscriptions/import endpoint for importing an existing subscription to an existing org
- add SubscriptionImport object and log as 'import' event in subscription events collection

---------
Co-authored-by: Ilya Kreymer <ikreymer@gmail.com>
This commit is contained in:
Tessa Walsh 2024-07-16 19:17:02 -04:00 committed by GitHub
parent 224b011070
commit 60afb19472
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 142 additions and 9 deletions

View File

@ -1128,6 +1128,23 @@ class SubscriptionCreateOut(SubscriptionCreate, SubscriptionEventOut):
type: Literal["create"] = "create"
# ============================================================================
class SubscriptionImport(BaseModel):
"""import subscription to existing org"""
subId: str
status: str
planId: str
oid: UUID
# ============================================================================
class SubscriptionImportOut(SubscriptionImport, SubscriptionEventOut):
"""Output model for subscription import event"""
type: Literal["import"] = "import"
# ============================================================================
class SubscriptionUpdate(BaseModel):
"""update subscription data"""
@ -2190,7 +2207,12 @@ class PaginatedSubscriptionEventResponse(PaginatedResponse):
"""Response model for paginated subscription events"""
items: List[
Union[SubscriptionCreateOut, SubscriptionUpdateOut, SubscriptionCancelOut]
Union[
SubscriptionCreateOut,
SubscriptionUpdateOut,
SubscriptionCancelOut,
SubscriptionImportOut,
]
]

View File

@ -367,6 +367,25 @@ class OrgOps:
return org
async def add_subscription_to_org(
self, subscription: Subscription, oid: UUID
) -> None:
"""Add subscription to existing org"""
org = await self.get_org_by_id(oid)
org.subscription = subscription
include = {"subscription"}
if subscription.status == PAUSED_PAYMENT_FAILED:
org.readOnly = True
org.readOnlyReason = REASON_PAUSED
include.add("readOnly")
include.add("readOnlyReason")
await self.orgs.find_one_and_update(
{"_id": org.id}, {"$set": org.dict(include=include)}
)
async def check_all_org_default_storages(self, storage_ops) -> None:
"""ensure all default storages references by this org actually exist

View File

@ -14,9 +14,11 @@ from .users import UserManager
from .utils import is_bool
from .models import (
SubscriptionCreate,
SubscriptionImport,
SubscriptionUpdate,
SubscriptionCancel,
SubscriptionCreateOut,
SubscriptionImportOut,
SubscriptionUpdateOut,
SubscriptionCancelOut,
Subscription,
@ -28,8 +30,10 @@ from .models import (
InviteAddedResponse,
User,
UserRole,
AddedResponseId,
UpdatedResponse,
PaginatedSubscriptionEventResponse,
REASON_CANCELED,
)
from .pagination import DEFAULT_PAGE_SIZE, paginated_format
from .utils import dt_now
@ -86,6 +90,19 @@ class SubOps:
return {"added": True, "id": new_org.id, "invited": invited, "token": token}
async def import_subscription(
self, sub_import: SubscriptionImport
) -> dict[str, Any]:
"""import subscription to existing org"""
subscription = Subscription(
subId=sub_import.subId, status=sub_import.status, planId=sub_import.planId
)
await self.org_ops.add_subscription_to_org(subscription, sub_import.oid)
await self.add_sub_event("import", sub_import, sub_import.oid)
return {"added": True, "id": sub_import.oid}
async def update_subscription(self, update: SubscriptionUpdate) -> dict[str, bool]:
"""update subs"""
@ -118,7 +135,7 @@ class SubOps:
deleted = False
await self.org_ops.update_read_only(
org, readOnly=True, readOnlyReason="subscriptionCanceled"
org, readOnly=True, readOnlyReason=REASON_CANCELED
)
if not org.subscription.readOnlyOnCancel:
@ -131,7 +148,12 @@ class SubOps:
async def add_sub_event(
self,
type_: str,
event: Union[SubscriptionCreate, SubscriptionUpdate, SubscriptionCancel],
event: Union[
SubscriptionCreate,
SubscriptionImport,
SubscriptionUpdate,
SubscriptionCancel,
],
oid: UUID,
) -> None:
"""add a subscription event to the db"""
@ -141,12 +163,17 @@ class SubOps:
data["oid"] = oid
await self.subs.insert_one(data)
def _get_sub_by_type_from_data(
self, data: dict[str, object]
) -> Union[SubscriptionCreateOut, SubscriptionUpdateOut, SubscriptionCancelOut]:
def _get_sub_by_type_from_data(self, data: dict[str, object]) -> Union[
SubscriptionCreateOut,
SubscriptionImportOut,
SubscriptionUpdateOut,
SubscriptionCancelOut,
]:
"""convert dict to propert background job type"""
if data["type"] == "create":
return SubscriptionCreateOut(**data)
if data["type"] == "import":
return SubscriptionImportOut(**data)
if data["type"] == "update":
return SubscriptionUpdateOut(**data)
return SubscriptionCancelOut(**data)
@ -164,7 +191,12 @@ class SubOps:
sort_direction: Optional[int] = -1,
) -> Tuple[
List[
Union[SubscriptionCreateOut, SubscriptionUpdateOut, SubscriptionCancelOut]
Union[
SubscriptionCreateOut,
SubscriptionImportOut,
SubscriptionUpdateOut,
SubscriptionCancelOut,
]
],
int,
]:
@ -289,6 +321,15 @@ def init_subs_api(
):
return await ops.create_new_subscription(create, user, request)
@app.post(
"/subscriptions/import",
tags=["subscriptions"],
dependencies=[Depends(user_or_shared_secret_dep)],
response_model=AddedResponseId,
)
async def import_sub(sub_import: SubscriptionImport):
return await ops.import_subscription(sub_import)
@app.post(
"/subscriptions/update",
tags=["subscriptions"],

View File

@ -1,6 +1,7 @@
import requests
from .conftest import API_PREFIX
from uuid import uuid4
new_subs_oid = None
@ -366,14 +367,57 @@ def test_cancel_sub_and_no_delete_org(admin_auth_headers):
assert r.json() == {"detail": "org_for_subscription_not_found"}
def test_subscription_events_log(admin_auth_headers):
def test_import_sub_invalid_org(admin_auth_headers):
r = requests.post(
f"{API_PREFIX}/subscriptions/import",
headers=admin_auth_headers,
json={
"subId": "345",
"planId": "basic",
"status": "active",
"oid": str(uuid4()),
},
)
assert r.status_code == 400
assert r.json() == {"detail": "invalid_org_id"}
def test_import_sub_existing_org(admin_auth_headers, non_default_org_id):
r = requests.post(
f"{API_PREFIX}/subscriptions/import",
headers=admin_auth_headers,
json={
"subId": "345",
"planId": "basic",
"status": "active",
"oid": non_default_org_id,
},
)
assert r.status_code == 200
assert r.json() == {"added": True, "id": non_default_org_id}
r = requests.get(
f"{API_PREFIX}/orgs/{non_default_org_id}", headers=admin_auth_headers
)
assert r.status_code == 200
data = r.json()
assert data["subscription"] == {
"subId": "345",
"status": "active",
"planId": "basic",
"futureCancelDate": None,
"readOnlyOnCancel": False,
}
def test_subscription_events_log(admin_auth_headers, non_default_org_id):
r = requests.get(f"{API_PREFIX}/subscriptions/events", headers=admin_auth_headers)
assert r.status_code == 200
data = r.json()
events = data["items"]
total = data["total"]
assert total == 6
assert total == 7
for event in events:
assert event["timestamp"]
@ -430,6 +474,13 @@ def test_subscription_events_log(admin_auth_headers):
},
{"subId": "123", "oid": new_subs_oid, "type": "cancel"},
{"subId": "234", "oid": new_subs_oid_2, "type": "cancel"},
{
"type": "import",
"subId": "345",
"oid": non_default_org_id,
"status": "active",
"planId": "basic",
},
]