diff --git a/backend/btrixcloud/models.py b/backend/btrixcloud/models.py index e277e34a..450df4cb 100644 --- a/backend/btrixcloud/models.py +++ b/backend/btrixcloud/models.py @@ -1128,6 +1128,23 @@ class SubscriptionCreateOut(SubscriptionCreate, SubscriptionEventOut): type: Literal["create"] = "create" +# ============================================================================ +class SubscriptionImport(BaseModel): + """import subscription to existing org""" + + subId: str + status: str + planId: str + oid: UUID + + +# ============================================================================ +class SubscriptionImportOut(SubscriptionImport, SubscriptionEventOut): + """Output model for subscription import event""" + + type: Literal["import"] = "import" + + # ============================================================================ class SubscriptionUpdate(BaseModel): """update subscription data""" @@ -2190,7 +2207,12 @@ class PaginatedSubscriptionEventResponse(PaginatedResponse): """Response model for paginated subscription events""" items: List[ - Union[SubscriptionCreateOut, SubscriptionUpdateOut, SubscriptionCancelOut] + Union[ + SubscriptionCreateOut, + SubscriptionUpdateOut, + SubscriptionCancelOut, + SubscriptionImportOut, + ] ] diff --git a/backend/btrixcloud/orgs.py b/backend/btrixcloud/orgs.py index 0768bb5a..aeeae67e 100644 --- a/backend/btrixcloud/orgs.py +++ b/backend/btrixcloud/orgs.py @@ -367,6 +367,25 @@ class OrgOps: return org + async def add_subscription_to_org( + self, subscription: Subscription, oid: UUID + ) -> None: + """Add subscription to existing org""" + org = await self.get_org_by_id(oid) + + org.subscription = subscription + include = {"subscription"} + + if subscription.status == PAUSED_PAYMENT_FAILED: + org.readOnly = True + org.readOnlyReason = REASON_PAUSED + include.add("readOnly") + include.add("readOnlyReason") + + await self.orgs.find_one_and_update( + {"_id": org.id}, {"$set": org.dict(include=include)} + ) + async def check_all_org_default_storages(self, storage_ops) -> None: """ensure all default storages references by this org actually exist diff --git a/backend/btrixcloud/subs.py b/backend/btrixcloud/subs.py index 6a3fe28d..b87a60f3 100644 --- a/backend/btrixcloud/subs.py +++ b/backend/btrixcloud/subs.py @@ -14,9 +14,11 @@ from .users import UserManager from .utils import is_bool from .models import ( SubscriptionCreate, + SubscriptionImport, SubscriptionUpdate, SubscriptionCancel, SubscriptionCreateOut, + SubscriptionImportOut, SubscriptionUpdateOut, SubscriptionCancelOut, Subscription, @@ -28,8 +30,10 @@ from .models import ( InviteAddedResponse, User, UserRole, + AddedResponseId, UpdatedResponse, PaginatedSubscriptionEventResponse, + REASON_CANCELED, ) from .pagination import DEFAULT_PAGE_SIZE, paginated_format from .utils import dt_now @@ -86,6 +90,19 @@ class SubOps: return {"added": True, "id": new_org.id, "invited": invited, "token": token} + async def import_subscription( + self, sub_import: SubscriptionImport + ) -> dict[str, Any]: + """import subscription to existing org""" + subscription = Subscription( + subId=sub_import.subId, status=sub_import.status, planId=sub_import.planId + ) + await self.org_ops.add_subscription_to_org(subscription, sub_import.oid) + + await self.add_sub_event("import", sub_import, sub_import.oid) + + return {"added": True, "id": sub_import.oid} + async def update_subscription(self, update: SubscriptionUpdate) -> dict[str, bool]: """update subs""" @@ -118,7 +135,7 @@ class SubOps: deleted = False await self.org_ops.update_read_only( - org, readOnly=True, readOnlyReason="subscriptionCanceled" + org, readOnly=True, readOnlyReason=REASON_CANCELED ) if not org.subscription.readOnlyOnCancel: @@ -131,7 +148,12 @@ class SubOps: async def add_sub_event( self, type_: str, - event: Union[SubscriptionCreate, SubscriptionUpdate, SubscriptionCancel], + event: Union[ + SubscriptionCreate, + SubscriptionImport, + SubscriptionUpdate, + SubscriptionCancel, + ], oid: UUID, ) -> None: """add a subscription event to the db""" @@ -141,12 +163,17 @@ class SubOps: data["oid"] = oid await self.subs.insert_one(data) - def _get_sub_by_type_from_data( - self, data: dict[str, object] - ) -> Union[SubscriptionCreateOut, SubscriptionUpdateOut, SubscriptionCancelOut]: + def _get_sub_by_type_from_data(self, data: dict[str, object]) -> Union[ + SubscriptionCreateOut, + SubscriptionImportOut, + SubscriptionUpdateOut, + SubscriptionCancelOut, + ]: """convert dict to propert background job type""" if data["type"] == "create": return SubscriptionCreateOut(**data) + if data["type"] == "import": + return SubscriptionImportOut(**data) if data["type"] == "update": return SubscriptionUpdateOut(**data) return SubscriptionCancelOut(**data) @@ -164,7 +191,12 @@ class SubOps: sort_direction: Optional[int] = -1, ) -> Tuple[ List[ - Union[SubscriptionCreateOut, SubscriptionUpdateOut, SubscriptionCancelOut] + Union[ + SubscriptionCreateOut, + SubscriptionImportOut, + SubscriptionUpdateOut, + SubscriptionCancelOut, + ] ], int, ]: @@ -289,6 +321,15 @@ def init_subs_api( ): return await ops.create_new_subscription(create, user, request) + @app.post( + "/subscriptions/import", + tags=["subscriptions"], + dependencies=[Depends(user_or_shared_secret_dep)], + response_model=AddedResponseId, + ) + async def import_sub(sub_import: SubscriptionImport): + return await ops.import_subscription(sub_import) + @app.post( "/subscriptions/update", tags=["subscriptions"], diff --git a/backend/test/test_org_subs.py b/backend/test/test_org_subs.py index 48804b7a..b6a1ccad 100644 --- a/backend/test/test_org_subs.py +++ b/backend/test/test_org_subs.py @@ -1,6 +1,7 @@ import requests from .conftest import API_PREFIX +from uuid import uuid4 new_subs_oid = None @@ -366,14 +367,57 @@ def test_cancel_sub_and_no_delete_org(admin_auth_headers): assert r.json() == {"detail": "org_for_subscription_not_found"} -def test_subscription_events_log(admin_auth_headers): +def test_import_sub_invalid_org(admin_auth_headers): + r = requests.post( + f"{API_PREFIX}/subscriptions/import", + headers=admin_auth_headers, + json={ + "subId": "345", + "planId": "basic", + "status": "active", + "oid": str(uuid4()), + }, + ) + assert r.status_code == 400 + assert r.json() == {"detail": "invalid_org_id"} + + +def test_import_sub_existing_org(admin_auth_headers, non_default_org_id): + r = requests.post( + f"{API_PREFIX}/subscriptions/import", + headers=admin_auth_headers, + json={ + "subId": "345", + "planId": "basic", + "status": "active", + "oid": non_default_org_id, + }, + ) + assert r.status_code == 200 + assert r.json() == {"added": True, "id": non_default_org_id} + + r = requests.get( + f"{API_PREFIX}/orgs/{non_default_org_id}", headers=admin_auth_headers + ) + assert r.status_code == 200 + data = r.json() + assert data["subscription"] == { + "subId": "345", + "status": "active", + "planId": "basic", + "futureCancelDate": None, + "readOnlyOnCancel": False, + } + + +def test_subscription_events_log(admin_auth_headers, non_default_org_id): r = requests.get(f"{API_PREFIX}/subscriptions/events", headers=admin_auth_headers) assert r.status_code == 200 data = r.json() events = data["items"] total = data["total"] - assert total == 6 + assert total == 7 for event in events: assert event["timestamp"] @@ -430,6 +474,13 @@ def test_subscription_events_log(admin_auth_headers): }, {"subId": "123", "oid": new_subs_oid, "type": "cancel"}, {"subId": "234", "oid": new_subs_oid_2, "type": "cancel"}, + { + "type": "import", + "subId": "345", + "oid": non_default_org_id, + "status": "active", + "planId": "basic", + }, ]