chore: case insensitive email (#29978)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
非法操作
2026-01-13 15:42:44 +08:00
committed by GitHub
parent 0e33dfb5c2
commit 491e1fd6a4
27 changed files with 1611 additions and 376 deletions

View File

@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -748,6 +748,21 @@ class AccountService:
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
@staticmethod
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
"""
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
This keeps backward compatibility for older records that stored uppercase emails while the
rest of the system gradually normalizes new inputs.
"""
query_session = session or db.session
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@@ -1363,16 +1378,22 @@ class RegisterService:
if not inviter:
raise ValueError("Inviter is required")
normalized_email = email.lower()
"""Invite new member"""
with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
name = email.split("@")[0]
name = normalized_email.split("@")[0]
account = cls.register(
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
email=normalized_email,
name=name,
language=language,
status=AccountStatus.PENDING,
is_setup=True,
)
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
@@ -1394,7 +1415,7 @@ class RegisterService:
# send email
send_invite_member_mail_task.delay(
language=language,
to=email,
to=account.email,
token=token,
inviter_name=inviter.name if inviter else "Dify",
workspace_name=tenant.name,
@@ -1493,6 +1514,16 @@ class RegisterService:
invitation: dict = json.loads(data)
return invitation
@classmethod
def get_invitation_with_case_fallback(
cls, workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
if invitation or not email or email == email.lower():
return invitation
normalized_email = email.lower()
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)

View File

@@ -12,6 +12,7 @@ from libs.passport import PassportService
from libs.password import compare_password
from models import Account, AccountStatus
from models.model import App, EndUser, Site
from services.account_service import AccountService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
@@ -32,7 +33,7 @@ class WebAppAuthService:
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
raise AccountNotFoundError()
@@ -52,7 +53,7 @@ class WebAppAuthService:
@classmethod
def get_user_through_email(cls, email: str):
account = db.session.query(Account).where(Account.email == email).first()
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
return None