diff --git a/api/services/account_service.py b/api/services/account_service.py index d4ba7520bdf..d02f244428e 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,7 +8,7 @@ from hashlib import sha256 from typing import Any, TypedDict, cast from pydantic import BaseModel, TypeAdapter -from sqlalchemy import func, select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session @@ -1392,10 +1392,10 @@ class RegisterService: db.session.add(dify_setup) db.session.commit() except Exception as e: - db.session.query(DifySetup).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Account).delete() - db.session.query(Tenant).delete() + db.session.execute(delete(DifySetup)) + db.session.execute(delete(TenantAccountJoin)) + db.session.execute(delete(Account)) + db.session.execute(delete(Tenant)) db.session.commit() logger.exception("Setup account failed, email: %s, name: %s", email, name) @@ -1496,7 +1496,11 @@ class RegisterService: TenantService.switch_tenant(account, tenant.id) else: TenantService.check_member_permission(tenant, inviter, account, "add") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .limit(1) + ) if not ta: TenantService.create_tenant_member(tenant, account, role) @@ -1553,21 +1557,18 @@ class RegisterService: if not invitation_data: return None - tenant = ( - db.session.query(Tenant) - .where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") - .first() + tenant = db.session.scalar( + select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1) ) if not tenant: return None - tenant_account = ( - db.session.query(Account, TenantAccountJoin.role) + tenant_account = db.session.execute( + select(Account, TenantAccountJoin.role) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) - .first() - ) + ).first() if not tenant_account: return None diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index af86949012b..cc64159c5f3 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1034,7 +1034,7 @@ class TestRegisterService: ) # Verify rollback operations were called - mock_db_dependencies["db"].session.query.assert_called() + mock_db_dependencies["db"].session.execute.assert_called() # ==================== Registration Tests ==================== @@ -1599,10 +1599,8 @@ class TestRegisterService: mock_session_class.return_value.__exit__.return_value = None mock_lookup.return_value = mock_existing_account - # Mock the db.session.query for TenantAccountJoin - mock_db_query = MagicMock() - mock_db_query.filter_by.return_value.first.return_value = None # No existing member - mock_db_dependencies["db"].session.query.return_value = mock_db_query + # Mock scalar for TenantAccountJoin lookup - no existing member + mock_db_dependencies["db"].session.scalar.return_value = None # Mock TenantService methods with ( @@ -1777,14 +1775,9 @@ class TestRegisterService: } mock_get_invitation_by_token.return_value = invitation_data - # Mock database queries - complex query mocking - mock_query1 = MagicMock() - mock_query1.where.return_value.first.return_value = mock_tenant - - mock_query2 = MagicMock() - mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal") - - mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2] + # Mock scalar for tenant lookup, execute for account+role lookup + mock_db_dependencies["db"].session.scalar.return_value = mock_tenant + mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal") # Execute test result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") @@ -1816,10 +1809,8 @@ class TestRegisterService: } mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() - # Mock database queries - no tenant found - mock_query = MagicMock() - mock_query.filter.return_value.first.return_value = None - mock_db_dependencies["db"].session.query.return_value = mock_query + # Mock scalar for tenant lookup - not found + mock_db_dependencies["db"].session.scalar.return_value = None # Execute test result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") @@ -1842,14 +1833,9 @@ class TestRegisterService: } mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() - # Mock database queries - mock_query1 = MagicMock() - mock_query1.filter.return_value.first.return_value = mock_tenant - - mock_query2 = MagicMock() - mock_query2.join.return_value.where.return_value.first.return_value = None # No account found - - mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2] + # Mock scalar for tenant, execute for account+role + mock_db_dependencies["db"].session.scalar.return_value = mock_tenant + mock_db_dependencies["db"].session.execute.return_value.first.return_value = None # No account found # Execute test result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") @@ -1875,14 +1861,9 @@ class TestRegisterService: } mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() - # Mock database queries - mock_query1 = MagicMock() - mock_query1.filter.return_value.first.return_value = mock_tenant - - mock_query2 = MagicMock() - mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal") - - mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2] + # Mock scalar for tenant, execute for account+role + mock_db_dependencies["db"].session.scalar.return_value = mock_tenant + mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal") # Execute test result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")