refactor: select in account_service (TenantService class) (#34499)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Renzo
2026-04-03 13:03:45 +02:00
committed by GitHub
parent 83d4176785
commit 06dde4f503
2 changed files with 86 additions and 99 deletions

View File

@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, TypedDict, cast from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter from pydantic import BaseModel, TypeAdapter
from sqlalchemy import delete, func, select from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -1069,11 +1069,11 @@ class TenantService:
@staticmethod @staticmethod
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False): def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
"""Check if user have a workspace or not""" """Check if user have a workspace or not"""
available_ta = ( available_ta = db.session.scalar(
db.session.query(TenantAccountJoin) select(TenantAccountJoin)
.filter_by(account_id=account.id) .where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc()) .order_by(TenantAccountJoin.id.asc())
.first() .limit(1)
) )
if available_ta: if available_ta:
@@ -1104,7 +1104,11 @@ class TenantService:
logger.error("Tenant %s has already an owner.", tenant.id) logger.error("Tenant %s has already an owner.", tenant.id)
raise Exception("Tenant already has an owner.") raise Exception("Tenant already has an owner.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if ta: if ta:
ta.role = TenantAccountRole(role) ta.role = TenantAccountRole(role)
else: else:
@@ -1119,11 +1123,12 @@ class TenantService:
@staticmethod @staticmethod
def get_join_tenants(account: Account) -> list[Tenant]: def get_join_tenants(account: Account) -> list[Tenant]:
"""Get account join tenants""" """Get account join tenants"""
return ( return list(
db.session.query(Tenant) db.session.scalars(
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) select(Tenant)
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.all() .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
).all()
) )
@staticmethod @staticmethod
@@ -1133,7 +1138,11 @@ class TenantService:
if not tenant: if not tenant:
raise TenantNotFoundError("Tenant not found.") raise TenantNotFoundError("Tenant not found.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if ta: if ta:
tenant.role = ta.role tenant.role = ta.role
else: else:
@@ -1148,23 +1157,25 @@ class TenantService:
if tenant_id is None: if tenant_id is None:
raise ValueError("Tenant ID must be provided.") raise ValueError("Tenant ID must be provided.")
tenant_account_join = ( tenant_account_join = db.session.scalar(
db.session.query(TenantAccountJoin) select(TenantAccountJoin)
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
.where( .where(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL, Tenant.status == TenantStatus.NORMAL,
) )
.first() .limit(1)
) )
if not tenant_account_join: if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else: else:
db.session.query(TenantAccountJoin).where( db.session.execute(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id update(TenantAccountJoin)
).update({"current": False}) .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id)
.values(current=False)
)
tenant_account_join.current = True tenant_account_join.current = True
# Set the current tenant for the account # Set the current tenant for the account
account.set_tenant_id(tenant_account_join.tenant_id) account.set_tenant_id(tenant_account_join.tenant_id)
@@ -1173,8 +1184,8 @@ class TenantService:
@staticmethod @staticmethod
def get_tenant_members(tenant: Tenant) -> list[Account]: def get_tenant_members(tenant: Tenant) -> list[Account]:
"""Get tenant members""" """Get tenant members"""
query = ( stmt = (
db.session.query(Account, TenantAccountJoin.role) select(Account, TenantAccountJoin.role)
.select_from(Account) .select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(TenantAccountJoin.tenant_id == tenant.id) .where(TenantAccountJoin.tenant_id == tenant.id)
@@ -1183,7 +1194,7 @@ class TenantService:
# Initialize an empty list to store the updated accounts # Initialize an empty list to store the updated accounts
updated_accounts = [] updated_accounts = []
for account, role in query: for account, role in db.session.execute(stmt):
account.role = role account.role = role
updated_accounts.append(account) updated_accounts.append(account)
@@ -1192,8 +1203,8 @@ class TenantService:
@staticmethod @staticmethod
def get_dataset_operator_members(tenant: Tenant) -> list[Account]: def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
"""Get dataset admin members""" """Get dataset admin members"""
query = ( stmt = (
db.session.query(Account, TenantAccountJoin.role) select(Account, TenantAccountJoin.role)
.select_from(Account) .select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(TenantAccountJoin.tenant_id == tenant.id) .where(TenantAccountJoin.tenant_id == tenant.id)
@@ -1203,7 +1214,7 @@ class TenantService:
# Initialize an empty list to store the updated accounts # Initialize an empty list to store the updated accounts
updated_accounts = [] updated_accounts = []
for account, role in query: for account, role in db.session.execute(stmt):
account.role = role account.role = role
updated_accounts.append(account) updated_accounts.append(account)
@@ -1216,26 +1227,31 @@ class TenantService:
raise ValueError("all roles must be TenantAccountRole") raise ValueError("all roles must be TenantAccountRole")
return ( return (
db.session.query(TenantAccountJoin) db.session.scalar(
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])) select(TenantAccountJoin)
.first() .where(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.role.in_([role.value for role in roles]),
)
.limit(1)
)
is not None is not None
) )
@staticmethod @staticmethod
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None: def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
"""Get the role of the current account for a given tenant""" """Get the role of the current account for a given tenant"""
join = ( join = db.session.scalar(
db.session.query(TenantAccountJoin) select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.first() .limit(1)
) )
return TenantAccountRole(join.role) if join else None return TenantAccountRole(join.role) if join else None
@staticmethod @staticmethod
def get_tenant_count() -> int: def get_tenant_count() -> int:
"""Get tenant count""" """Get tenant count"""
return cast(int, db.session.query(func.count(Tenant.id)).scalar()) return cast(int, db.session.scalar(select(func.count(Tenant.id))))
@staticmethod @staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str): def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
@@ -1252,7 +1268,11 @@ class TenantService:
if operator.id == member.id: if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.") raise CannotOperateSelfError("Cannot operate self.")
ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first() ta_operator = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id)
.limit(1)
)
if not ta_operator or ta_operator.role not in perms[action]: if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.") raise NoPermissionError(f"No permission to {action} member.")
@@ -1270,7 +1290,11 @@ class TenantService:
TenantService.check_member_permission(tenant, operator, account, "remove") TenantService.check_member_permission(tenant, operator, account, "remove")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not ta: if not ta:
raise MemberNotInTenantError("Member not in tenant.") raise MemberNotInTenantError("Member not in tenant.")
@@ -1285,7 +1309,12 @@ class TenantService:
should_delete_account = False should_delete_account = False
if account.status == AccountStatus.PENDING: if account.status == AccountStatus.PENDING:
# autoflush flushes ta deletion before this query, so 0 means no remaining joins # autoflush flushes ta deletion before this query, so 0 means no remaining joins
remaining_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).count() remaining_joins = (
db.session.scalar(
select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
)
or 0
)
if remaining_joins == 0: if remaining_joins == 0:
db.session.delete(account) db.session.delete(account)
should_delete_account = True should_delete_account = True
@@ -1320,8 +1349,10 @@ class TenantService:
"""Update member role""" """Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update") TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = ( target_member_join = db.session.scalar(
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first() select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
.limit(1)
) )
if not target_member_join: if not target_member_join:
@@ -1332,8 +1363,10 @@ class TenantService:
if new_role == "owner": if new_role == "owner":
# Find the current owner and change their role to 'admin' # Find the current owner and change their role to 'admin'
current_owner_join = ( current_owner_join = db.session.scalar(
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
) )
if current_owner_join: if current_owner_join:
current_owner_join.role = TenantAccountRole.ADMIN current_owner_join.role = TenantAccountRole.ADMIN

View File

@@ -556,12 +556,8 @@ class TestTenantService:
# Setup test data # Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock() mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock - no existing tenant joins # Mock scalar to return None (no existing tenant joins)
query_results = { mock_db_dependencies["db"].session.scalar.return_value = None
("TenantAccountJoin", "account_id", "user-123"): None,
("TenantAccountJoin", "tenant_id", "tenant-456"): None, # For has_roles check
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Setup external service mocks # Setup external service mocks
mock_external_service_dependencies[ mock_external_service_dependencies[
@@ -650,9 +646,8 @@ class TestTenantService:
mock_tenant.id = "tenant-456" mock_tenant.id = "tenant-456"
mock_account = TestAccountAssociatedDataFactory.create_account_mock() mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock - no existing member # Mock scalar to return None (no existing member)
query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): None} mock_db_dependencies["db"].session.scalar.return_value = None
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Mock database operations # Mock database operations
mock_db_dependencies["db"].session.add = MagicMock() mock_db_dependencies["db"].session.add = MagicMock()
@@ -693,16 +688,8 @@ class TestTenantService:
tenant_id="tenant-456", account_id="operator-123", role="owner" tenant_id="tenant-456", account_id="operator-123", role="owner"
) )
query_mock_permission = MagicMock() # scalar calls: permission check, ta lookup, remaining count
query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta, 0]
query_mock_ta = MagicMock()
query_mock_ta.filter_by.return_value.first.return_value = mock_ta
query_mock_count = MagicMock()
query_mock_count.filter_by.return_value.count.return_value = 0
mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta, query_mock_count]
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync: with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
mock_sync.return_value = True mock_sync.return_value = True
@@ -741,17 +728,8 @@ class TestTenantService:
tenant_id="tenant-456", account_id="operator-123", role="owner" tenant_id="tenant-456", account_id="operator-123", role="owner"
) )
query_mock_permission = MagicMock() # scalar calls: permission check, ta lookup, remaining count = 1
query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta, 1]
query_mock_ta = MagicMock()
query_mock_ta.filter_by.return_value.first.return_value = mock_ta
# Remaining join count = 1 (still in another workspace)
query_mock_count = MagicMock()
query_mock_count.filter_by.return_value.count.return_value = 1
mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta, query_mock_count]
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync: with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
mock_sync.return_value = True mock_sync.return_value = True
@@ -781,13 +759,8 @@ class TestTenantService:
tenant_id="tenant-456", account_id="operator-123", role="owner" tenant_id="tenant-456", account_id="operator-123", role="owner"
) )
query_mock_permission = MagicMock() # scalar calls: permission check, ta lookup (no count needed for active member)
query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta]
query_mock_ta = MagicMock()
query_mock_ta.filter_by.return_value.first.return_value = mock_ta
mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta]
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync: with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
mock_sync.return_value = True mock_sync.return_value = True
@@ -810,13 +783,8 @@ class TestTenantService:
# Mock the complex query in switch_tenant method # Mock the complex query in switch_tenant method
with patch("services.account_service.db") as mock_db: with patch("services.account_service.db") as mock_db:
# Mock the join query that returns the tenant_account_join # Mock scalar for the join query
mock_query = MagicMock() mock_db.session.scalar.return_value = mock_tenant_join
mock_where = MagicMock()
mock_where.first.return_value = mock_tenant_join
mock_query.where.return_value = mock_where
mock_query.join.return_value = mock_query
mock_db.session.query.return_value = mock_query
# Execute test # Execute test
TenantService.switch_tenant(mock_account, "tenant-456") TenantService.switch_tenant(mock_account, "tenant-456")
@@ -851,20 +819,8 @@ class TestTenantService:
# Mock the database queries in update_member_role method # Mock the database queries in update_member_role method
with patch("services.account_service.db") as mock_db: with patch("services.account_service.db") as mock_db:
# Mock the first query for operator permission check # scalar calls: permission check, target member lookup
mock_query1 = MagicMock() mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join]
mock_filter1 = MagicMock()
mock_filter1.first.return_value = mock_operator_join
mock_query1.filter_by.return_value = mock_filter1
# Mock the second query for target member
mock_query2 = MagicMock()
mock_filter2 = MagicMock()
mock_filter2.first.return_value = mock_target_join
mock_query2.filter_by.return_value = mock_filter2
# Make the query method return different mocks for different calls
mock_db.session.query.side_effect = [mock_query1, mock_query2]
# Execute test # Execute test
TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator) TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator)
@@ -886,9 +842,7 @@ class TestTenantService:
tenant_id="tenant-456", account_id="operator-123", role="owner" tenant_id="tenant-456", account_id="operator-123", role="owner"
) )
# Setup smart database query mock mock_db_dependencies["db"].session.scalar.return_value = mock_operator_join
query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): mock_operator_join}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Execute test - should not raise exception # Execute test - should not raise exception
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add") TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add")