users: add case-insensitive index to maintain backwards compatibility with fastapi-users (#1319)
follow up to #1290 Based on implementation in: https://github.com/fastapi-users/fastapi-users-db-mongodb/blob/main/fastapi_users_db_mongodb/__init__.py
This commit is contained in:
		
							parent
							
								
									3c884f94c9
								
							
						
					
					
						commit
						c1d3beda9c
					
				| @ -19,6 +19,7 @@ from fastapi import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| from pymongo.errors import DuplicateKeyError | from pymongo.errors import DuplicateKeyError | ||||||
|  | from pymongo.collation import Collation | ||||||
| 
 | 
 | ||||||
| from .models import ( | from .models import ( | ||||||
|     UserCreate, |     UserCreate, | ||||||
| @ -65,6 +66,8 @@ class UserManager: | |||||||
|         self.invites = invites |         self.invites = invites | ||||||
|         self.org_ops = None |         self.org_ops = None | ||||||
| 
 | 
 | ||||||
|  |         self.email_collation = Collation("en", strength=2) | ||||||
|  | 
 | ||||||
|         self.registration_enabled = is_bool(os.environ.get("REGISTRATION_ENABLED")) |         self.registration_enabled = is_bool(os.environ.get("REGISTRATION_ENABLED")) | ||||||
| 
 | 
 | ||||||
|     # pylint: disable=attribute-defined-outside-init |     # pylint: disable=attribute-defined-outside-init | ||||||
| @ -78,6 +81,13 @@ class UserManager: | |||||||
|         """init lookup index""" |         """init lookup index""" | ||||||
|         await self.users.create_index("id", unique=True) |         await self.users.create_index("id", unique=True) | ||||||
|         await self.users.create_index("email", unique=True) |         await self.users.create_index("email", unique=True) | ||||||
|  | 
 | ||||||
|  |         await self.users.create_index( | ||||||
|  |             "email", | ||||||
|  |             name="case_insensitive_email_index", | ||||||
|  |             collation=self.email_collation, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|         # Expire failed logins object after one hour |         # Expire failed logins object after one hour | ||||||
|         await self.failed_logins.create_index("attempted", expireAfterSeconds=3600) |         await self.failed_logins.create_index("attempted", expireAfterSeconds=3600) | ||||||
| 
 | 
 | ||||||
| @ -379,7 +389,9 @@ class UserManager: | |||||||
| 
 | 
 | ||||||
|     async def get_by_email(self, email: str) -> Optional[User]: |     async def get_by_email(self, email: str) -> Optional[User]: | ||||||
|         """get user by email""" |         """get user by email""" | ||||||
|         user = await self.users.find_one({"email": email}) |         user = await self.users.find_one( | ||||||
|  |             {"email": email}, collation=self.email_collation | ||||||
|  |         ) | ||||||
|         if not user: |         if not user: | ||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
| @ -535,7 +547,9 @@ class UserManager: | |||||||
| 
 | 
 | ||||||
|     async def reset_failed_logins(self, email: str) -> None: |     async def reset_failed_logins(self, email: str) -> None: | ||||||
|         """Reset consecutive failed login attempts by deleting FailedLogin object""" |         """Reset consecutive failed login attempts by deleting FailedLogin object""" | ||||||
|         await self.failed_logins.delete_one({"email": email}) |         await self.failed_logins.delete_one( | ||||||
|  |             {"email": email}, collation=self.email_collation | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     async def inc_failed_logins(self, email: str) -> None: |     async def inc_failed_logins(self, email: str) -> None: | ||||||
|         """Inc consecutive failed login attempts for user by 1 |         """Inc consecutive failed login attempts for user by 1 | ||||||
| @ -552,11 +566,14 @@ class UserManager: | |||||||
|                 "$inc": {"count": 1}, |                 "$inc": {"count": 1}, | ||||||
|             }, |             }, | ||||||
|             upsert=True, |             upsert=True, | ||||||
|  |             collation=self.email_collation, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     async def get_failed_logins_count(self, email: str) -> int: |     async def get_failed_logins_count(self, email: str) -> int: | ||||||
|         """Get failed login attempts for user, falling back to 0""" |         """Get failed login attempts for user, falling back to 0""" | ||||||
|         failed_login = await self.failed_logins.find_one({"email": email}) |         failed_login = await self.failed_logins.find_one( | ||||||
|  |             {"email": email}, collation=self.email_collation | ||||||
|  |         ) | ||||||
|         if not failed_login: |         if not failed_login: | ||||||
|             return 0 |             return 0 | ||||||
|         return failed_login.get("count", 0) |         return failed_login.get("count", 0) | ||||||
|  | |||||||
| @ -15,7 +15,8 @@ ADMIN_PW = "PASSW0RD!" | |||||||
| VIEWER_USERNAME = "viewer@example.com" | VIEWER_USERNAME = "viewer@example.com" | ||||||
| VIEWER_PW = "viewerPASSW0RD!" | VIEWER_PW = "viewerPASSW0RD!" | ||||||
| 
 | 
 | ||||||
| CRAWLER_USERNAME = "crawler@example.com" | CRAWLER_USERNAME = "CraWleR@example.com" | ||||||
|  | CRAWLER_USERNAME_LOWERCASE = "crawler@example.com" | ||||||
| CRAWLER_PW = "crawlerPASSWORD!" | CRAWLER_PW = "crawlerPASSWORD!" | ||||||
| 
 | 
 | ||||||
| _admin_config_id = None | _admin_config_id = None | ||||||
|  | |||||||
| @ -4,6 +4,8 @@ import time | |||||||
| from .conftest import ( | from .conftest import ( | ||||||
|     API_PREFIX, |     API_PREFIX, | ||||||
|     CRAWLER_USERNAME, |     CRAWLER_USERNAME, | ||||||
|  |     CRAWLER_USERNAME_LOWERCASE, | ||||||
|  |     CRAWLER_PW, | ||||||
|     ADMIN_PW, |     ADMIN_PW, | ||||||
|     ADMIN_USERNAME, |     ADMIN_USERNAME, | ||||||
|     FINISHED_STATES, |     FINISHED_STATES, | ||||||
| @ -14,7 +16,6 @@ VALID_USER_PW = "validpassw0rd!" | |||||||
| VALID_USER_PW_RESET = "new!password" | VALID_USER_PW_RESET = "new!password" | ||||||
| VALID_USER_PW_RESET_AGAIN = "new!password1" | VALID_USER_PW_RESET_AGAIN = "new!password1" | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| my_id = None | my_id = None | ||||||
| valid_user_headers = None | valid_user_headers = None | ||||||
| 
 | 
 | ||||||
| @ -71,6 +72,20 @@ def test_me_id(admin_auth_headers, default_org_id): | |||||||
|     assert r.status_code == 404 |     assert r.status_code == 404 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_login_case_insensitive_email(): | ||||||
|  |     r = requests.post( | ||||||
|  |         f"{API_PREFIX}/auth/jwt/login", | ||||||
|  |         data={ | ||||||
|  |             "username": CRAWLER_USERNAME_LOWERCASE, | ||||||
|  |             "password": CRAWLER_PW, | ||||||
|  |             "grant_type": "password", | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     data = r.json() | ||||||
|  |     assert r.status_code == 200 | ||||||
|  |     assert data["access_token"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def test_add_user_to_org_invalid_password(admin_auth_headers, default_org_id): | def test_add_user_to_org_invalid_password(admin_auth_headers, default_org_id): | ||||||
|     r = requests.post( |     r = requests.post( | ||||||
|         f"{API_PREFIX}/orgs/{default_org_id}/add-user", |         f"{API_PREFIX}/orgs/{default_org_id}/add-user", | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user