Include user and user org info in login response (#2014)

Fixes #2013 

Adds the `/users/me` response data to the API login endpoint response
under the key `user_info` and adds a test.
This commit is contained in:
Tessa Walsh 2024-08-12 21:51:42 -04:00 committed by GitHub
parent 1a6892572d
commit 916813af2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 5 deletions

View File

@ -21,7 +21,7 @@ from fastapi import (
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from .models import User
from .models import User, UserOut
from .utils import dt_now
@ -57,6 +57,7 @@ class BearerResponse(BaseModel):
access_token: str
token_type: str
user_info: UserOut
# ============================================================================
@ -181,10 +182,12 @@ def init_jwt_auth(user_manager):
auth_jwt_router = APIRouter()
def get_bearer_response(user: User):
def get_bearer_response(user: User, user_info: UserOut):
"""get token, return bearer response for user"""
token = create_access_token(user)
return BearerResponse(access_token=token, token_type="bearer")
return BearerResponse(
access_token=token, token_type="bearer", user_info=user_info
)
@auth_jwt_router.post("/login", response_model=BearerResponse)
async def login(
@ -246,10 +249,12 @@ def init_jwt_auth(user_manager):
# successfully logged in, reset failed logins, return user
await user_manager.reset_failed_logins(login_email)
return get_bearer_response(user)
user_info = await user_manager.get_user_info_with_orgs(user)
return get_bearer_response(user, user_info)
@auth_jwt_router.post("/refresh", response_model=BearerResponse)
async def refresh_jwt(user=Depends(current_active_user)):
return get_bearer_response(user)
user_info = await user_manager.get_user_info_with_orgs(user)
return get_bearer_response(user, user_info)
return auth_jwt_router, current_active_user, shared_secret_or_active_user

View File

@ -78,6 +78,45 @@ def test_me_id(admin_auth_headers, default_org_id):
assert r.status_code == 404
def test_login_user_info(admin_auth_headers, crawler_userid, default_org_id):
# Get default org info for comparison
r = requests.get(f"{API_PREFIX}/orgs", headers=admin_auth_headers)
default_org = [org for org in r.json()["items"] if org["default"]][0]
# Log in and check response
r = requests.post(
f"{API_PREFIX}/auth/jwt/login",
data={
"username": CRAWLER_USERNAME,
"password": CRAWLER_PW,
"grant_type": "password",
},
)
data = r.json()
assert r.status_code == 200
assert data["access_token"]
assert data["token_type"] == "bearer"
user_info = data["user_info"]
assert user_info
assert user_info["id"] == crawler_userid
assert user_info["name"] == "new-crawler"
assert user_info["email"] == CRAWLER_USERNAME
assert user_info["is_superuser"] is False
assert user_info["is_verified"]
user_orgs = user_info["orgs"]
assert len(user_orgs) == 1
org = user_orgs[0]
assert org["id"] == default_org_id
assert org["name"] == default_org["name"]
assert org["slug"] == default_org["slug"]
assert org["default"]
assert org["role"] == 20
def test_login_case_insensitive_email():
r = requests.post(
f"{API_PREFIX}/auth/jwt/login",