diff --git a/api/services/account_service.py b/api/services/account_service.py index 29b14447305..d4ba7520bdf 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -144,22 +144,26 @@ class AccountService: @staticmethod def load_user(user_id: str) -> None | Account: - account = db.session.query(Account).filter_by(id=user_id).first() + account = db.session.get(Account, user_id) if not account: return None if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") - current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first() + current_tenant = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True) + .limit(1) + ) if current_tenant: account.set_tenant_id(current_tenant.tenant_id) else: - available_ta = ( - db.session.query(TenantAccountJoin) - .filter_by(account_id=account.id) + available_ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == account.id) .order_by(TenantAccountJoin.id.asc()) - .first() + .limit(1) ) if not available_ta: return None @@ -195,7 +199,7 @@ class AccountService: def authenticate(email: str, password: str, invite_token: str | None = None) -> Account: """authenticate account with email and password""" - account = db.session.query(Account).filter_by(email=email).first() + account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) if not account: raise AccountPasswordError("Invalid email or password.") @@ -371,8 +375,10 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: AccountIntegrate | None = ( - db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + account_integrate: AccountIntegrate | None = db.session.scalar( + select(AccountIntegrate) + .where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider) + .limit(1) ) if account_integrate: @@ -416,7 +422,9 @@ class AccountService: def update_account_email(account: Account, email: str) -> Account: """Update account email""" account.email = email - account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first() + account_integrate = db.session.scalar( + select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1) + ) if account_integrate: db.session.delete(account_integrate) db.session.add(account) @@ -818,7 +826,7 @@ class AccountService: ) ) - account = db.session.query(Account).where(Account.email == email).first() + account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) if not account: return None @@ -1018,7 +1026,7 @@ class AccountService: @staticmethod def check_email_unique(email: str) -> bool: - return db.session.query(Account).filter_by(email=email).first() is None + return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None class TenantService: diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index dcd6785464c..af86949012b 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -173,9 +173,7 @@ class TestAccountService: # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock() - # Setup smart database query mock - query_results = {("Account", "email", "test@example.com"): mock_account} - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.scalar.return_value = mock_account mock_password_dependencies["compare_password"].return_value = True @@ -188,9 +186,7 @@ class TestAccountService: def test_authenticate_account_not_found(self, mock_db_dependencies): """Test authentication when account does not exist.""" - # Setup smart database query mock - no matching results - query_results = {("Account", "email", "notfound@example.com"): None} - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.scalar.return_value = None # Execute test and verify exception self._assert_exception_raised( @@ -202,9 +198,7 @@ class TestAccountService: # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned") - # Setup smart database query mock - query_results = {("Account", "email", "banned@example.com"): mock_account} - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.scalar.return_value = mock_account # Execute test and verify exception self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password") @@ -214,9 +208,7 @@ class TestAccountService: # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock() - # Setup smart database query mock - query_results = {("Account", "email", "test@example.com"): mock_account} - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.scalar.return_value = mock_account mock_password_dependencies["compare_password"].return_value = False @@ -230,9 +222,7 @@ class TestAccountService: # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending") - # Setup smart database query mock - query_results = {("Account", "email", "pending@example.com"): mock_account} - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.scalar.return_value = mock_account mock_password_dependencies["compare_password"].return_value = True @@ -422,12 +412,8 @@ class TestAccountService: mock_account = TestAccountAssociatedDataFactory.create_account_mock() mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock() - # Setup smart database query mock - query_results = { - ("Account", "id", "user-123"): mock_account, - ("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join, - } - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.get.return_value = mock_account + mock_db_dependencies["db"].session.scalar.return_value = mock_tenant_join # Mock datetime with patch("services.account_service.datetime") as mock_datetime: @@ -444,9 +430,7 @@ class TestAccountService: def test_load_user_not_found(self, mock_db_dependencies): """Test user loading when user does not exist.""" - # Setup smart database query mock - no matching results - query_results = {("Account", "id", "non-existent-user"): None} - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.get.return_value = None # Execute test result = AccountService.load_user("non-existent-user") @@ -459,9 +443,7 @@ class TestAccountService: # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned") - # Setup smart database query mock - query_results = {("Account", "id", "user-123"): mock_account} - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.get.return_value = mock_account # Execute test and verify exception self._assert_exception_raised( @@ -476,13 +458,9 @@ class TestAccountService: mock_account = TestAccountAssociatedDataFactory.create_account_mock() mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False) - # Setup smart database query mock for complex scenario - query_results = { - ("Account", "id", "user-123"): mock_account, - ("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant - ("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant - } - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.get.return_value = mock_account + # First scalar: current tenant (None), second scalar: available tenant + mock_db_dependencies["db"].session.scalar.side_effect = [None, mock_available_tenant] # Mock datetime with patch("services.account_service.datetime") as mock_datetime: @@ -503,13 +481,9 @@ class TestAccountService: # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock() - # Setup smart database query mock for no tenants scenario - query_results = { - ("Account", "id", "user-123"): mock_account, - ("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant - ("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants - } - ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + mock_db_dependencies["db"].session.get.return_value = mock_account + # First scalar: current tenant (None), second scalar: available tenant (None) + mock_db_dependencies["db"].session.scalar.side_effect = [None, None] # Mock datetime with patch("services.account_service.datetime") as mock_datetime: