refactor: select in account_service (TenantService class) (#34499)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Renzo
2026-04-03 13:03:45 +02:00
committed by GitHub
parent 83d4176785
commit 06dde4f503
2 changed files with 86 additions and 99 deletions

View File

@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import delete, func, select
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session
@@ -1069,11 +1069,11 @@ class TenantService:
@staticmethod
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
"""Check if user have a workspace or not"""
available_ta = (
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
available_ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
.limit(1)
)
if available_ta:
@@ -1104,7 +1104,11 @@ class TenantService:
logger.error("Tenant %s has already an owner.", tenant.id)
raise Exception("Tenant already has an owner.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if ta:
ta.role = TenantAccountRole(role)
else:
@@ -1119,11 +1123,12 @@ class TenantService:
@staticmethod
def get_join_tenants(account: Account) -> list[Tenant]:
"""Get account join tenants"""
return (
db.session.query(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
.all()
return list(
db.session.scalars(
select(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
).all()
)
@staticmethod
@@ -1133,7 +1138,11 @@ class TenantService:
if not tenant:
raise TenantNotFoundError("Tenant not found.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if ta:
tenant.role = ta.role
else:
@@ -1148,23 +1157,25 @@ class TenantService:
if tenant_id is None:
raise ValueError("Tenant ID must be provided.")
tenant_account_join = (
db.session.query(TenantAccountJoin)
tenant_account_join = db.session.scalar(
select(TenantAccountJoin)
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
.where(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL,
)
.first()
.limit(1)
)
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
db.session.query(TenantAccountJoin).where(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
).update({"current": False})
db.session.execute(
update(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id)
.values(current=False)
)
tenant_account_join.current = True
# Set the current tenant for the account
account.set_tenant_id(tenant_account_join.tenant_id)
@@ -1173,8 +1184,8 @@ class TenantService:
@staticmethod
def get_tenant_members(tenant: Tenant) -> list[Account]:
"""Get tenant members"""
query = (
db.session.query(Account, TenantAccountJoin.role)
stmt = (
select(Account, TenantAccountJoin.role)
.select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(TenantAccountJoin.tenant_id == tenant.id)
@@ -1183,7 +1194,7 @@ class TenantService:
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in query:
for account, role in db.session.execute(stmt):
account.role = role
updated_accounts.append(account)
@@ -1192,8 +1203,8 @@ class TenantService:
@staticmethod
def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
"""Get dataset admin members"""
query = (
db.session.query(Account, TenantAccountJoin.role)
stmt = (
select(Account, TenantAccountJoin.role)
.select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(TenantAccountJoin.tenant_id == tenant.id)
@@ -1203,7 +1214,7 @@ class TenantService:
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in query:
for account, role in db.session.execute(stmt):
account.role = role
updated_accounts.append(account)
@@ -1216,26 +1227,31 @@ class TenantService:
raise ValueError("all roles must be TenantAccountRole")
return (
db.session.query(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]))
.first()
db.session.scalar(
select(TenantAccountJoin)
.where(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.role.in_([role.value for role in roles]),
)
.limit(1)
)
is not None
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
"""Get the role of the current account for a given tenant"""
join = (
db.session.query(TenantAccountJoin)
join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.first()
.limit(1)
)
return TenantAccountRole(join.role) if join else None
@staticmethod
def get_tenant_count() -> int:
"""Get tenant count"""
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
return cast(int, db.session.scalar(select(func.count(Tenant.id))))
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
@@ -1252,7 +1268,11 @@ class TenantService:
if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.")
ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first()
ta_operator = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id)
.limit(1)
)
if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.")
@@ -1270,7 +1290,11 @@ class TenantService:
TenantService.check_member_permission(tenant, operator, account, "remove")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not ta:
raise MemberNotInTenantError("Member not in tenant.")
@@ -1285,7 +1309,12 @@ class TenantService:
should_delete_account = False
if account.status == AccountStatus.PENDING:
# autoflush flushes ta deletion before this query, so 0 means no remaining joins
remaining_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).count()
remaining_joins = (
db.session.scalar(
select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
)
or 0
)
if remaining_joins == 0:
db.session.delete(account)
should_delete_account = True
@@ -1320,8 +1349,10 @@ class TenantService:
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first()
target_member_join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
.limit(1)
)
if not target_member_join:
@@ -1332,8 +1363,10 @@ class TenantService:
if new_role == "owner":
# Find the current owner and change their role to 'admin'
current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
current_owner_join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
)
if current_owner_join:
current_owner_join.role = TenantAccountRole.ADMIN

View File

@@ -556,12 +556,8 @@ class TestTenantService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock - no existing tenant joins
query_results = {
("TenantAccountJoin", "account_id", "user-123"): None,
("TenantAccountJoin", "tenant_id", "tenant-456"): None, # For has_roles check
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Mock scalar to return None (no existing tenant joins)
mock_db_dependencies["db"].session.scalar.return_value = None
# Setup external service mocks
mock_external_service_dependencies[
@@ -650,9 +646,8 @@ class TestTenantService:
mock_tenant.id = "tenant-456"
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock - no existing member
query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Mock scalar to return None (no existing member)
mock_db_dependencies["db"].session.scalar.return_value = None
# Mock database operations
mock_db_dependencies["db"].session.add = MagicMock()
@@ -693,16 +688,8 @@ class TestTenantService:
tenant_id="tenant-456", account_id="operator-123", role="owner"
)
query_mock_permission = MagicMock()
query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join
query_mock_ta = MagicMock()
query_mock_ta.filter_by.return_value.first.return_value = mock_ta
query_mock_count = MagicMock()
query_mock_count.filter_by.return_value.count.return_value = 0
mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta, query_mock_count]
# scalar calls: permission check, ta lookup, remaining count
mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta, 0]
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
mock_sync.return_value = True
@@ -741,17 +728,8 @@ class TestTenantService:
tenant_id="tenant-456", account_id="operator-123", role="owner"
)
query_mock_permission = MagicMock()
query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join
query_mock_ta = MagicMock()
query_mock_ta.filter_by.return_value.first.return_value = mock_ta
# Remaining join count = 1 (still in another workspace)
query_mock_count = MagicMock()
query_mock_count.filter_by.return_value.count.return_value = 1
mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta, query_mock_count]
# scalar calls: permission check, ta lookup, remaining count = 1
mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta, 1]
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
mock_sync.return_value = True
@@ -781,13 +759,8 @@ class TestTenantService:
tenant_id="tenant-456", account_id="operator-123", role="owner"
)
query_mock_permission = MagicMock()
query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join
query_mock_ta = MagicMock()
query_mock_ta.filter_by.return_value.first.return_value = mock_ta
mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta]
# scalar calls: permission check, ta lookup (no count needed for active member)
mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta]
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
mock_sync.return_value = True
@@ -810,13 +783,8 @@ class TestTenantService:
# Mock the complex query in switch_tenant method
with patch("services.account_service.db") as mock_db:
# Mock the join query that returns the tenant_account_join
mock_query = MagicMock()
mock_where = MagicMock()
mock_where.first.return_value = mock_tenant_join
mock_query.where.return_value = mock_where
mock_query.join.return_value = mock_query
mock_db.session.query.return_value = mock_query
# Mock scalar for the join query
mock_db.session.scalar.return_value = mock_tenant_join
# Execute test
TenantService.switch_tenant(mock_account, "tenant-456")
@@ -851,20 +819,8 @@ class TestTenantService:
# Mock the database queries in update_member_role method
with patch("services.account_service.db") as mock_db:
# Mock the first query for operator permission check
mock_query1 = MagicMock()
mock_filter1 = MagicMock()
mock_filter1.first.return_value = mock_operator_join
mock_query1.filter_by.return_value = mock_filter1
# Mock the second query for target member
mock_query2 = MagicMock()
mock_filter2 = MagicMock()
mock_filter2.first.return_value = mock_target_join
mock_query2.filter_by.return_value = mock_filter2
# Make the query method return different mocks for different calls
mock_db.session.query.side_effect = [mock_query1, mock_query2]
# scalar calls: permission check, target member lookup
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join]
# Execute test
TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator)
@@ -886,9 +842,7 @@ class TestTenantService:
tenant_id="tenant-456", account_id="operator-123", role="owner"
)
# Setup smart database query mock
query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): mock_operator_join}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_operator_join
# Execute test - should not raise exception
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add")