refactor: select in account_service (AccountService class) (#34496)

This commit is contained in:
Renzo
2026-04-03 05:41:46 +02:00
committed by GitHub
parent e55bd61c17
commit 33d4fd357c
2 changed files with 35 additions and 53 deletions

View File

@@ -144,22 +144,26 @@ class AccountService:
@staticmethod @staticmethod
def load_user(user_id: str) -> None | Account: def load_user(user_id: str) -> None | Account:
account = db.session.query(Account).filter_by(id=user_id).first() account = db.session.get(Account, user_id)
if not account: if not account:
return None return None
if account.status == AccountStatus.BANNED: if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.") raise Unauthorized("Account is banned.")
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first() current_tenant = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
.limit(1)
)
if current_tenant: if current_tenant:
account.set_tenant_id(current_tenant.tenant_id) account.set_tenant_id(current_tenant.tenant_id)
else: else:
available_ta = ( available_ta = db.session.scalar(
db.session.query(TenantAccountJoin) select(TenantAccountJoin)
.filter_by(account_id=account.id) .where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc()) .order_by(TenantAccountJoin.id.asc())
.first() .limit(1)
) )
if not available_ta: if not available_ta:
return None return None
@@ -195,7 +199,7 @@ class AccountService:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account: def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password""" """authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first() account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account: if not account:
raise AccountPasswordError("Invalid email or password.") raise AccountPasswordError("Invalid email or password.")
@@ -371,8 +375,10 @@ class AccountService:
"""Link account integrate""" """Link account integrate"""
try: try:
# Query whether there is an existing binding record for the same provider # Query whether there is an existing binding record for the same provider
account_integrate: AccountIntegrate | None = ( account_integrate: AccountIntegrate | None = db.session.scalar(
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() select(AccountIntegrate)
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
.limit(1)
) )
if account_integrate: if account_integrate:
@@ -416,7 +422,9 @@ class AccountService:
def update_account_email(account: Account, email: str) -> Account: def update_account_email(account: Account, email: str) -> Account:
"""Update account email""" """Update account email"""
account.email = email account.email = email
account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first() account_integrate = db.session.scalar(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
)
if account_integrate: if account_integrate:
db.session.delete(account_integrate) db.session.delete(account_integrate)
db.session.add(account) db.session.add(account)
@@ -818,7 +826,7 @@ class AccountService:
) )
) )
account = db.session.query(Account).where(Account.email == email).first() account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account: if not account:
return None return None
@@ -1018,7 +1026,7 @@ class AccountService:
@staticmethod @staticmethod
def check_email_unique(email: str) -> bool: def check_email_unique(email: str) -> bool:
return db.session.query(Account).filter_by(email=email).first() is None return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
class TenantService: class TenantService:

View File

@@ -173,9 +173,7 @@ class TestAccountService:
# Setup test data # Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock() mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock mock_db_dependencies["db"].session.scalar.return_value = mock_account
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_password_dependencies["compare_password"].return_value = True mock_password_dependencies["compare_password"].return_value = True
@@ -188,9 +186,7 @@ class TestAccountService:
def test_authenticate_account_not_found(self, mock_db_dependencies): def test_authenticate_account_not_found(self, mock_db_dependencies):
"""Test authentication when account does not exist.""" """Test authentication when account does not exist."""
# Setup smart database query mock - no matching results mock_db_dependencies["db"].session.scalar.return_value = None
query_results = {("Account", "email", "notfound@example.com"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Execute test and verify exception # Execute test and verify exception
self._assert_exception_raised( self._assert_exception_raised(
@@ -202,9 +198,7 @@ class TestAccountService:
# Setup test data # Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned") mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock mock_db_dependencies["db"].session.scalar.return_value = mock_account
query_results = {("Account", "email", "banned@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Execute test and verify exception # Execute test and verify exception
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password") self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
@@ -214,9 +208,7 @@ class TestAccountService:
# Setup test data # Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock() mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock mock_db_dependencies["db"].session.scalar.return_value = mock_account
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_password_dependencies["compare_password"].return_value = False mock_password_dependencies["compare_password"].return_value = False
@@ -230,9 +222,7 @@ class TestAccountService:
# Setup test data # Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending") mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
# Setup smart database query mock mock_db_dependencies["db"].session.scalar.return_value = mock_account
query_results = {("Account", "email", "pending@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_password_dependencies["compare_password"].return_value = True mock_password_dependencies["compare_password"].return_value = True
@@ -422,12 +412,8 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock() mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock() mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
# Setup smart database query mock mock_db_dependencies["db"].session.get.return_value = mock_account
query_results = { mock_db_dependencies["db"].session.scalar.return_value = mock_tenant_join
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Mock datetime # Mock datetime
with patch("services.account_service.datetime") as mock_datetime: with patch("services.account_service.datetime") as mock_datetime:
@@ -444,9 +430,7 @@ class TestAccountService:
def test_load_user_not_found(self, mock_db_dependencies): def test_load_user_not_found(self, mock_db_dependencies):
"""Test user loading when user does not exist.""" """Test user loading when user does not exist."""
# Setup smart database query mock - no matching results mock_db_dependencies["db"].session.get.return_value = None
query_results = {("Account", "id", "non-existent-user"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Execute test # Execute test
result = AccountService.load_user("non-existent-user") result = AccountService.load_user("non-existent-user")
@@ -459,9 +443,7 @@ class TestAccountService:
# Setup test data # Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned") mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock mock_db_dependencies["db"].session.get.return_value = mock_account
query_results = {("Account", "id", "user-123"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Execute test and verify exception # Execute test and verify exception
self._assert_exception_raised( self._assert_exception_raised(
@@ -476,13 +458,9 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock() mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False) mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
# Setup smart database query mock for complex scenario mock_db_dependencies["db"].session.get.return_value = mock_account
query_results = { # First scalar: current tenant (None), second scalar: available tenant
("Account", "id", "user-123"): mock_account, mock_db_dependencies["db"].session.scalar.side_effect = [None, mock_available_tenant]
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Mock datetime # Mock datetime
with patch("services.account_service.datetime") as mock_datetime: with patch("services.account_service.datetime") as mock_datetime:
@@ -503,13 +481,9 @@ class TestAccountService:
# Setup test data # Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock() mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock for no tenants scenario mock_db_dependencies["db"].session.get.return_value = mock_account
query_results = { # First scalar: current tenant (None), second scalar: available tenant (None)
("Account", "id", "user-123"): mock_account, mock_db_dependencies["db"].session.scalar.side_effect = [None, None]
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Mock datetime # Mock datetime
with patch("services.account_service.datetime") as mock_datetime: with patch("services.account_service.datetime") as mock_datetime: