diff --git a/api/services/account_service.py b/api/services/account_service.py index d02f244428e..ee4c199df80 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,7 +8,7 @@ from hashlib import sha256 from typing import Any, TypedDict, cast from pydantic import BaseModel, TypeAdapter -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, select, update from sqlalchemy.orm import Session @@ -1069,11 +1069,11 @@ class TenantService: @staticmethod def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False): """Check if user have a workspace or not""" - available_ta = ( - db.session.query(TenantAccountJoin) - .filter_by(account_id=account.id) + available_ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == account.id) .order_by(TenantAccountJoin.id.asc()) - .first() + .limit(1) ) if available_ta: @@ -1104,7 +1104,11 @@ class TenantService: logger.error("Tenant %s has already an owner.", tenant.id) raise Exception("Tenant already has an owner.") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .limit(1) + ) if ta: ta.role = TenantAccountRole(role) else: @@ -1119,11 +1123,12 @@ class TenantService: @staticmethod def get_join_tenants(account: Account) -> list[Tenant]: """Get account join tenants""" - return ( - db.session.query(Tenant) - .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) - .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) - .all() + return list( + db.session.scalars( + select(Tenant) + .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) + .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) + ).all() ) @staticmethod @@ -1133,7 +1138,11 @@ class TenantService: if not tenant: raise TenantNotFoundError("Tenant not found.") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .limit(1) + ) if ta: tenant.role = ta.role else: @@ -1148,23 +1157,25 @@ class TenantService: if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = ( - db.session.query(TenantAccountJoin) + tenant_account_join = db.session.scalar( + select(TenantAccountJoin) .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) .where( TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id == tenant_id, Tenant.status == TenantStatus.NORMAL, ) - .first() + .limit(1) ) if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - db.session.query(TenantAccountJoin).where( - TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id - ).update({"current": False}) + db.session.execute( + update(TenantAccountJoin) + .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id) + .values(current=False) + ) tenant_account_join.current = True # Set the current tenant for the account account.set_tenant_id(tenant_account_join.tenant_id) @@ -1173,8 +1184,8 @@ class TenantService: @staticmethod def get_tenant_members(tenant: Tenant) -> list[Account]: """Get tenant members""" - query = ( - db.session.query(Account, TenantAccountJoin.role) + stmt = ( + select(Account, TenantAccountJoin.role) .select_from(Account) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .where(TenantAccountJoin.tenant_id == tenant.id) @@ -1183,7 +1194,7 @@ class TenantService: # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in query: + for account, role in db.session.execute(stmt): account.role = role updated_accounts.append(account) @@ -1192,8 +1203,8 @@ class TenantService: @staticmethod def get_dataset_operator_members(tenant: Tenant) -> list[Account]: """Get dataset admin members""" - query = ( - db.session.query(Account, TenantAccountJoin.role) + stmt = ( + select(Account, TenantAccountJoin.role) .select_from(Account) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .where(TenantAccountJoin.tenant_id == tenant.id) @@ -1203,7 +1214,7 @@ class TenantService: # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in query: + for account, role in db.session.execute(stmt): account.role = role updated_accounts.append(account) @@ -1216,26 +1227,31 @@ class TenantService: raise ValueError("all roles must be TenantAccountRole") return ( - db.session.query(TenantAccountJoin) - .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])) - .first() + db.session.scalar( + select(TenantAccountJoin) + .where( + TenantAccountJoin.tenant_id == tenant.id, + TenantAccountJoin.role.in_([role.value for role in roles]), + ) + .limit(1) + ) is not None ) @staticmethod def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None: """Get the role of the current account for a given tenant""" - join = ( - db.session.query(TenantAccountJoin) + join = db.session.scalar( + select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) - .first() + .limit(1) ) return TenantAccountRole(join.role) if join else None @staticmethod def get_tenant_count() -> int: """Get tenant count""" - return cast(int, db.session.query(func.count(Tenant.id)).scalar()) + return cast(int, db.session.scalar(select(func.count(Tenant.id)))) @staticmethod def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str): @@ -1252,7 +1268,11 @@ class TenantService: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first() + ta_operator = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id) + .limit(1) + ) if not ta_operator or ta_operator.role not in perms[action]: raise NoPermissionError(f"No permission to {action} member.") @@ -1270,7 +1290,11 @@ class TenantService: TenantService.check_member_permission(tenant, operator, account, "remove") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .limit(1) + ) if not ta: raise MemberNotInTenantError("Member not in tenant.") @@ -1285,7 +1309,12 @@ class TenantService: should_delete_account = False if account.status == AccountStatus.PENDING: # autoflush flushes ta deletion before this query, so 0 means no remaining joins - remaining_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).count() + remaining_joins = ( + db.session.scalar( + select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id) + ) + or 0 + ) if remaining_joins == 0: db.session.delete(account) should_delete_account = True @@ -1320,8 +1349,10 @@ class TenantService: """Update member role""" TenantService.check_member_permission(tenant, operator, member, "update") - target_member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first() + target_member_join = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id) + .limit(1) ) if not target_member_join: @@ -1332,8 +1363,10 @@ class TenantService: if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) ) if current_owner_join: current_owner_join.role = TenantAccountRole.ADMIN diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index cc64159c5f3..041929c5fa4 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -556,12 +556,8 @@ class TestTenantService: # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock() - # Setup smart database query mock - no existing tenant joins - query_results = { - ("TenantAccountJoin", "account_id", "user-123"): None, - ("TenantAccountJoin", "tenant_id", "tenant-456"): None, # For has_roles check - } - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + # Mock scalar to return None (no existing tenant joins) + mock_db_dependencies["db"].session.scalar.return_value = None # Setup external service mocks mock_external_service_dependencies[ @@ -650,9 +646,8 @@ class TestTenantService: mock_tenant.id = "tenant-456" mock_account = TestAccountAssociatedDataFactory.create_account_mock() - # Setup smart database query mock - no existing member - query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): None} - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + # Mock scalar to return None (no existing member) + mock_db_dependencies["db"].session.scalar.return_value = None # Mock database operations mock_db_dependencies["db"].session.add = MagicMock() @@ -693,16 +688,8 @@ class TestTenantService: tenant_id="tenant-456", account_id="operator-123", role="owner" ) - query_mock_permission = MagicMock() - query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join - - query_mock_ta = MagicMock() - query_mock_ta.filter_by.return_value.first.return_value = mock_ta - - query_mock_count = MagicMock() - query_mock_count.filter_by.return_value.count.return_value = 0 - - mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta, query_mock_count] + # scalar calls: permission check, ta lookup, remaining count + mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta, 0] with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync: mock_sync.return_value = True @@ -741,17 +728,8 @@ class TestTenantService: tenant_id="tenant-456", account_id="operator-123", role="owner" ) - query_mock_permission = MagicMock() - query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join - - query_mock_ta = MagicMock() - query_mock_ta.filter_by.return_value.first.return_value = mock_ta - - # Remaining join count = 1 (still in another workspace) - query_mock_count = MagicMock() - query_mock_count.filter_by.return_value.count.return_value = 1 - - mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta, query_mock_count] + # scalar calls: permission check, ta lookup, remaining count = 1 + mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta, 1] with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync: mock_sync.return_value = True @@ -781,13 +759,8 @@ class TestTenantService: tenant_id="tenant-456", account_id="operator-123", role="owner" ) - query_mock_permission = MagicMock() - query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join - - query_mock_ta = MagicMock() - query_mock_ta.filter_by.return_value.first.return_value = mock_ta - - mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta] + # scalar calls: permission check, ta lookup (no count needed for active member) + mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta] with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync: mock_sync.return_value = True @@ -810,13 +783,8 @@ class TestTenantService: # Mock the complex query in switch_tenant method with patch("services.account_service.db") as mock_db: - # Mock the join query that returns the tenant_account_join - mock_query = MagicMock() - mock_where = MagicMock() - mock_where.first.return_value = mock_tenant_join - mock_query.where.return_value = mock_where - mock_query.join.return_value = mock_query - mock_db.session.query.return_value = mock_query + # Mock scalar for the join query + mock_db.session.scalar.return_value = mock_tenant_join # Execute test TenantService.switch_tenant(mock_account, "tenant-456") @@ -851,20 +819,8 @@ class TestTenantService: # Mock the database queries in update_member_role method with patch("services.account_service.db") as mock_db: - # Mock the first query for operator permission check - mock_query1 = MagicMock() - mock_filter1 = MagicMock() - mock_filter1.first.return_value = mock_operator_join - mock_query1.filter_by.return_value = mock_filter1 - - # Mock the second query for target member - mock_query2 = MagicMock() - mock_filter2 = MagicMock() - mock_filter2.first.return_value = mock_target_join - mock_query2.filter_by.return_value = mock_filter2 - - # Make the query method return different mocks for different calls - mock_db.session.query.side_effect = [mock_query1, mock_query2] + # scalar calls: permission check, target member lookup + mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join] # Execute test TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator) @@ -886,9 +842,7 @@ class TestTenantService: tenant_id="tenant-456", account_id="operator-123", role="owner" ) - # Setup smart database query mock - query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): mock_operator_join} - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.scalar.return_value = mock_operator_join # Execute test - should not raise exception TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add")