mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 06:09:24 +08:00
refactor: select in account_service (AccountService class) (#34496)
This commit is contained in:
@@ -144,22 +144,26 @@ class AccountService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_user(user_id: str) -> None | Account:
|
def load_user(user_id: str) -> None | Account:
|
||||||
account = db.session.query(Account).filter_by(id=user_id).first()
|
account = db.session.get(Account, user_id)
|
||||||
if not account:
|
if not account:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if account.status == AccountStatus.BANNED:
|
if account.status == AccountStatus.BANNED:
|
||||||
raise Unauthorized("Account is banned.")
|
raise Unauthorized("Account is banned.")
|
||||||
|
|
||||||
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
|
current_tenant = db.session.scalar(
|
||||||
|
select(TenantAccountJoin)
|
||||||
|
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
if current_tenant:
|
if current_tenant:
|
||||||
account.set_tenant_id(current_tenant.tenant_id)
|
account.set_tenant_id(current_tenant.tenant_id)
|
||||||
else:
|
else:
|
||||||
available_ta = (
|
available_ta = db.session.scalar(
|
||||||
db.session.query(TenantAccountJoin)
|
select(TenantAccountJoin)
|
||||||
.filter_by(account_id=account.id)
|
.where(TenantAccountJoin.account_id == account.id)
|
||||||
.order_by(TenantAccountJoin.id.asc())
|
.order_by(TenantAccountJoin.id.asc())
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not available_ta:
|
if not available_ta:
|
||||||
return None
|
return None
|
||||||
@@ -195,7 +199,7 @@ class AccountService:
|
|||||||
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
|
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
|
||||||
"""authenticate account with email and password"""
|
"""authenticate account with email and password"""
|
||||||
|
|
||||||
account = db.session.query(Account).filter_by(email=email).first()
|
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||||
if not account:
|
if not account:
|
||||||
raise AccountPasswordError("Invalid email or password.")
|
raise AccountPasswordError("Invalid email or password.")
|
||||||
|
|
||||||
@@ -371,8 +375,10 @@ class AccountService:
|
|||||||
"""Link account integrate"""
|
"""Link account integrate"""
|
||||||
try:
|
try:
|
||||||
# Query whether there is an existing binding record for the same provider
|
# Query whether there is an existing binding record for the same provider
|
||||||
account_integrate: AccountIntegrate | None = (
|
account_integrate: AccountIntegrate | None = db.session.scalar(
|
||||||
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
|
select(AccountIntegrate)
|
||||||
|
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if account_integrate:
|
if account_integrate:
|
||||||
@@ -416,7 +422,9 @@ class AccountService:
|
|||||||
def update_account_email(account: Account, email: str) -> Account:
|
def update_account_email(account: Account, email: str) -> Account:
|
||||||
"""Update account email"""
|
"""Update account email"""
|
||||||
account.email = email
|
account.email = email
|
||||||
account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first()
|
account_integrate = db.session.scalar(
|
||||||
|
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
|
||||||
|
)
|
||||||
if account_integrate:
|
if account_integrate:
|
||||||
db.session.delete(account_integrate)
|
db.session.delete(account_integrate)
|
||||||
db.session.add(account)
|
db.session.add(account)
|
||||||
@@ -818,7 +826,7 @@ class AccountService:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
account = db.session.query(Account).where(Account.email == email).first()
|
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||||
if not account:
|
if not account:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -1018,7 +1026,7 @@ class AccountService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_email_unique(email: str) -> bool:
|
def check_email_unique(email: str) -> bool:
|
||||||
return db.session.query(Account).filter_by(email=email).first() is None
|
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
|
||||||
|
|
||||||
|
|
||||||
class TenantService:
|
class TenantService:
|
||||||
|
|||||||
@@ -173,9 +173,7 @@ class TestAccountService:
|
|||||||
# Setup test data
|
# Setup test data
|
||||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||||
|
|
||||||
# Setup smart database query mock
|
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||||
query_results = {("Account", "email", "test@example.com"): mock_account}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
mock_password_dependencies["compare_password"].return_value = True
|
mock_password_dependencies["compare_password"].return_value = True
|
||||||
|
|
||||||
@@ -188,9 +186,7 @@ class TestAccountService:
|
|||||||
|
|
||||||
def test_authenticate_account_not_found(self, mock_db_dependencies):
|
def test_authenticate_account_not_found(self, mock_db_dependencies):
|
||||||
"""Test authentication when account does not exist."""
|
"""Test authentication when account does not exist."""
|
||||||
# Setup smart database query mock - no matching results
|
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||||
query_results = {("Account", "email", "notfound@example.com"): None}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
# Execute test and verify exception
|
# Execute test and verify exception
|
||||||
self._assert_exception_raised(
|
self._assert_exception_raised(
|
||||||
@@ -202,9 +198,7 @@ class TestAccountService:
|
|||||||
# Setup test data
|
# Setup test data
|
||||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
|
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
|
||||||
|
|
||||||
# Setup smart database query mock
|
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||||
query_results = {("Account", "email", "banned@example.com"): mock_account}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
# Execute test and verify exception
|
# Execute test and verify exception
|
||||||
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
|
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
|
||||||
@@ -214,9 +208,7 @@ class TestAccountService:
|
|||||||
# Setup test data
|
# Setup test data
|
||||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||||
|
|
||||||
# Setup smart database query mock
|
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||||
query_results = {("Account", "email", "test@example.com"): mock_account}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
mock_password_dependencies["compare_password"].return_value = False
|
mock_password_dependencies["compare_password"].return_value = False
|
||||||
|
|
||||||
@@ -230,9 +222,7 @@ class TestAccountService:
|
|||||||
# Setup test data
|
# Setup test data
|
||||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
|
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
|
||||||
|
|
||||||
# Setup smart database query mock
|
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||||
query_results = {("Account", "email", "pending@example.com"): mock_account}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
mock_password_dependencies["compare_password"].return_value = True
|
mock_password_dependencies["compare_password"].return_value = True
|
||||||
|
|
||||||
@@ -422,12 +412,8 @@ class TestAccountService:
|
|||||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||||
mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
|
mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
|
||||||
|
|
||||||
# Setup smart database query mock
|
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||||
query_results = {
|
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant_join
|
||||||
("Account", "id", "user-123"): mock_account,
|
|
||||||
("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
|
|
||||||
}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
# Mock datetime
|
# Mock datetime
|
||||||
with patch("services.account_service.datetime") as mock_datetime:
|
with patch("services.account_service.datetime") as mock_datetime:
|
||||||
@@ -444,9 +430,7 @@ class TestAccountService:
|
|||||||
|
|
||||||
def test_load_user_not_found(self, mock_db_dependencies):
|
def test_load_user_not_found(self, mock_db_dependencies):
|
||||||
"""Test user loading when user does not exist."""
|
"""Test user loading when user does not exist."""
|
||||||
# Setup smart database query mock - no matching results
|
mock_db_dependencies["db"].session.get.return_value = None
|
||||||
query_results = {("Account", "id", "non-existent-user"): None}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
# Execute test
|
# Execute test
|
||||||
result = AccountService.load_user("non-existent-user")
|
result = AccountService.load_user("non-existent-user")
|
||||||
@@ -459,9 +443,7 @@ class TestAccountService:
|
|||||||
# Setup test data
|
# Setup test data
|
||||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
|
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
|
||||||
|
|
||||||
# Setup smart database query mock
|
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||||
query_results = {("Account", "id", "user-123"): mock_account}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
# Execute test and verify exception
|
# Execute test and verify exception
|
||||||
self._assert_exception_raised(
|
self._assert_exception_raised(
|
||||||
@@ -476,13 +458,9 @@ class TestAccountService:
|
|||||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||||
mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
|
mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
|
||||||
|
|
||||||
# Setup smart database query mock for complex scenario
|
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||||
query_results = {
|
# First scalar: current tenant (None), second scalar: available tenant
|
||||||
("Account", "id", "user-123"): mock_account,
|
mock_db_dependencies["db"].session.scalar.side_effect = [None, mock_available_tenant]
|
||||||
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
|
|
||||||
("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
|
|
||||||
}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
# Mock datetime
|
# Mock datetime
|
||||||
with patch("services.account_service.datetime") as mock_datetime:
|
with patch("services.account_service.datetime") as mock_datetime:
|
||||||
@@ -503,13 +481,9 @@ class TestAccountService:
|
|||||||
# Setup test data
|
# Setup test data
|
||||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||||
|
|
||||||
# Setup smart database query mock for no tenants scenario
|
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||||
query_results = {
|
# First scalar: current tenant (None), second scalar: available tenant (None)
|
||||||
("Account", "id", "user-123"): mock_account,
|
mock_db_dependencies["db"].session.scalar.side_effect = [None, None]
|
||||||
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
|
|
||||||
("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
|
|
||||||
}
|
|
||||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
|
||||||
|
|
||||||
# Mock datetime
|
# Mock datetime
|
||||||
with patch("services.account_service.datetime") as mock_datetime:
|
with patch("services.account_service.datetime") as mock_datetime:
|
||||||
|
|||||||
Reference in New Issue
Block a user