Ensure email comparisons are case-insensitive, emails stored as lowercase (#2084) (#2086) (fixes from 1.11.7)

- Add a custom EmailStr type which lowercases the full e-mail, not just
the domain.
- Ensure EmailStr is used throughout wherever e-mails are used, both for
invites and user models
- Tests: update to check for lowercase email responses, e-mails returned
from APIs are always lowercase
- Tests: remove tests where '@' was ur-lencoded, should not be possible
since POSTing JSON and no url-decoding is done/expected. E-mails should
have '@' present.
- Fixes #2083 where invites were rejected due to case differences
- CI: pin pymongo dependency due to latest releases update, update python used for CI
This commit is contained in:
Ilya Kreymer 2024-09-19 12:20:34 -07:00 committed by GitHub
parent a8f4f8cfc3
commit feb6b1f26c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 36 additions and 23 deletions

View File

@ -81,9 +81,9 @@ jobs:
helm upgrade --install -f ./chart/values.yaml -f ./chart/test/test.yaml btrix ./chart/ helm upgrade --install -f ./chart/values.yaml -f ./chart/test/test.yaml btrix ./chart/
- name: Install Python - name: Install Python
uses: actions/setup-python@v3 uses: actions/setup-python@v5
with: with:
python-version: '3.9' python-version: 3.x
- name: Install Python Libs - name: Install Python Libs
run: pip install -r ./backend/test-requirements.txt run: pip install -r ./backend/test-requirements.txt

View File

@ -13,6 +13,7 @@ from fastapi import HTTPException
from .pagination import DEFAULT_PAGE_SIZE from .pagination import DEFAULT_PAGE_SIZE
from .models import ( from .models import (
EmailStr,
UserRole, UserRole,
InvitePending, InvitePending,
InviteRequest, InviteRequest,
@ -133,7 +134,10 @@ class InviteOps:
) )
async def get_valid_invite( async def get_valid_invite(
self, invite_token: UUID, email: Optional[str], userid: Optional[UUID] = None self,
invite_token: UUID,
email: Optional[EmailStr],
userid: Optional[UUID] = None,
) -> InvitePending: ) -> InvitePending:
"""Retrieve a valid invite data from db, or throw if invalid""" """Retrieve a valid invite data from db, or throw if invalid"""
token_hash = get_hash(invite_token) token_hash = get_hash(invite_token)
@ -156,7 +160,7 @@ class InviteOps:
await self.invites.delete_one({"_id": invite_token}) await self.invites.delete_one({"_id": invite_token})
async def remove_invite_by_email( async def remove_invite_by_email(
self, email: str, oid: Optional[UUID] = None self, email: EmailStr, oid: Optional[UUID] = None
) -> Any: ) -> Any:
"""remove invite from invite list by email""" """remove invite from invite list by email"""
query: dict[str, object] = {"email": email} query: dict[str, object] = {"email": email}

View File

@ -15,7 +15,8 @@ from pydantic import (
Field, Field,
HttpUrl as HttpUrlNonStr, HttpUrl as HttpUrlNonStr,
AnyHttpUrl as AnyHttpUrlNonStr, AnyHttpUrl as AnyHttpUrlNonStr,
EmailStr, EmailStr as CasedEmailStr,
validate_email,
RootModel, RootModel,
BeforeValidator, BeforeValidator,
TypeAdapter, TypeAdapter,
@ -47,6 +48,15 @@ HttpUrl = Annotated[
] ]
# pylint: disable=too-few-public-methods
class EmailStr(CasedEmailStr):
"""EmailStr type that lowercases the full email"""
@classmethod
def _validate(cls, value: CasedEmailStr, /) -> CasedEmailStr:
return validate_email(value)[1].lower()
# pylint: disable=invalid-name, too-many-lines # pylint: disable=invalid-name, too-many-lines
# ============================================================================ # ============================================================================
class UserRole(IntEnum): class UserRole(IntEnum):
@ -70,11 +80,11 @@ class InvitePending(BaseMongoModel):
id: UUID id: UUID
created: datetime created: datetime
tokenHash: str tokenHash: str
inviterEmail: str inviterEmail: EmailStr
fromSuperuser: Optional[bool] = False fromSuperuser: Optional[bool] = False
oid: Optional[UUID] = None oid: Optional[UUID] = None
role: UserRole = UserRole.VIEWER role: UserRole = UserRole.VIEWER
email: Optional[str] = "" email: Optional[EmailStr] = None
# set if existing user # set if existing user
userid: Optional[UUID] = None userid: Optional[UUID] = None
@ -84,13 +94,13 @@ class InviteOut(BaseModel):
"""Single invite output model""" """Single invite output model"""
created: datetime created: datetime
inviterEmail: str inviterEmail: EmailStr
inviterName: str inviterName: str
oid: Optional[UUID] = None oid: Optional[UUID] = None
orgName: Optional[str] = None orgName: Optional[str] = None
orgSlug: Optional[str] = None orgSlug: Optional[str] = None
role: UserRole = UserRole.VIEWER role: UserRole = UserRole.VIEWER
email: Optional[str] = "" email: Optional[EmailStr] = None
firstOrgAdmin: Optional[bool] = None firstOrgAdmin: Optional[bool] = None
@ -98,7 +108,7 @@ class InviteOut(BaseModel):
class InviteRequest(BaseModel): class InviteRequest(BaseModel):
"""Request to invite another user""" """Request to invite another user"""
email: str email: EmailStr
# ============================================================================ # ============================================================================
@ -1179,7 +1189,7 @@ class SubscriptionCreate(BaseModel):
status: str status: str
planId: str planId: str
firstAdminInviteEmail: str firstAdminInviteEmail: EmailStr
quotas: Optional[OrgQuotas] = None quotas: Optional[OrgQuotas] = None

View File

@ -8,7 +8,6 @@ import json
import math import math
import os import os
import time import time
import urllib.parse
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
@ -1614,9 +1613,7 @@ def init_orgs_api(
async def delete_invite( async def delete_invite(
invite: RemovePendingInvite, org: Organization = Depends(org_owner_dep) invite: RemovePendingInvite, org: Organization = Depends(org_owner_dep)
): ):
# URL decode email just in case result = await user_manager.invites.remove_invite_by_email(invite.email, org.id)
email = urllib.parse.unquote(invite.email)
result = await user_manager.invites.remove_invite_by_email(email, org.id)
if result.deleted_count > 0: if result.deleted_count > 0:
return { return {
"removed": True, "removed": True,

View File

@ -8,8 +8,6 @@ import asyncio
from typing import Optional, List, TYPE_CHECKING, cast, Callable from typing import Optional, List, TYPE_CHECKING, cast, Callable
from pydantic import EmailStr
from fastapi import ( from fastapi import (
Request, Request,
HTTPException, HTTPException,
@ -22,6 +20,7 @@ from pymongo.errors import DuplicateKeyError
from pymongo.collation import Collation from pymongo.collation import Collation
from .models import ( from .models import (
EmailStr,
UserCreate, UserCreate,
UserUpdateEmailName, UserUpdateEmailName,
UserUpdatePassword, UserUpdatePassword,
@ -685,7 +684,7 @@ def init_users_router(
return await user_manager.invites.get_invite_out(invite, user_manager, True) return await user_manager.invites.get_invite_out(invite, user_manager, True)
@users_router.get("/invite/{token}", tags=["invites"], response_model=InviteOut) @users_router.get("/invite/{token}", tags=["invites"], response_model=InviteOut)
async def get_invite_info(token: UUID, email: str): async def get_invite_info(token: UUID, email: EmailStr):
invite = await user_manager.invites.get_valid_invite(token, email) invite = await user_manager.invites.get_valid_invite(token, email)
return await user_manager.invites.get_invite_out(invite, user_manager, True) return await user_manager.invites.get_invite_out(invite, user_manager, True)

View File

@ -2,6 +2,7 @@ gunicorn
uvicorn[standard] uvicorn[standard]
fastapi==0.103.2 fastapi==0.103.2
motor==3.3.1 motor==3.3.1
pymongo==4.8.0
passlib passlib
PyJWT==2.8.0 PyJWT==2.8.0
pydantic==2.8.2 pydantic==2.8.2

View File

@ -360,16 +360,18 @@ def test_get_pending_org_invites(
("user+comment-org@example.com", "user+comment-org@example.com"), ("user+comment-org@example.com", "user+comment-org@example.com"),
# URL encoded email address with comments # URL encoded email address with comments
( (
"user%2Bcomment-encoded-org%40example.com", "user%2Bcomment-encoded-org@example.com",
"user+comment-encoded-org@example.com", "user+comment-encoded-org@example.com",
), ),
# User email with diacritic characters # User email with diacritic characters
("diacritic-tést-org@example.com", "diacritic-tést-org@example.com"), ("diacritic-tést-org@example.com", "diacritic-tést-org@example.com"),
# User email with encoded diacritic characters # User email with encoded diacritic characters
( (
"diacritic-t%C3%A9st-encoded-org%40example.com", "diacritic-t%C3%A9st-encoded-org@example.com",
"diacritic-tést-encoded-org@example.com", "diacritic-tést-encoded-org@example.com",
), ),
# User email with upper case characters, stored as all lowercase
("exampleName@EXAMple.com", "examplename@example.com"),
], ],
) )
def test_send_and_accept_org_invite( def test_send_and_accept_org_invite(

View File

@ -12,7 +12,7 @@ existing_user_invite_token = None
VALID_PASSWORD = "ValidPassW0rd!" VALID_PASSWORD = "ValidPassW0rd!"
invite_email = "test-user@example.com" invite_email = "test-User@EXample.com"
def test_create_sub_org_invalid_auth(crawler_auth_headers): def test_create_sub_org_invalid_auth(crawler_auth_headers):

View File

@ -50,7 +50,7 @@ def test_me_with_orgs(crawler_auth_headers, default_org_id):
assert r.status_code == 200 assert r.status_code == 200
data = r.json() data = r.json()
assert data["email"] == CRAWLER_USERNAME assert data["email"] == CRAWLER_USERNAME_LOWERCASE
assert data["id"] assert data["id"]
# assert data["is_active"] # assert data["is_active"]
assert data["is_superuser"] is False assert data["is_superuser"] is False
@ -102,7 +102,7 @@ def test_login_user_info(admin_auth_headers, crawler_userid, default_org_id):
assert user_info["id"] == crawler_userid assert user_info["id"] == crawler_userid
assert user_info["name"] == "new-crawler" assert user_info["name"] == "new-crawler"
assert user_info["email"] == CRAWLER_USERNAME assert user_info["email"] == CRAWLER_USERNAME_LOWERCASE
assert user_info["is_superuser"] is False assert user_info["is_superuser"] is False
assert user_info["is_verified"] assert user_info["is_verified"]