diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 58ec76243b9..3c59535a48f 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -4,7 +4,7 @@ from flask import Response from flask_restx import Resource from graphon.variables.input_entities import VariableEntity from pydantic import BaseModel, Field, ValidationError -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from controllers.common.schema import register_schema_model from controllers.mcp import mcp_ns @@ -67,7 +67,7 @@ class MCPAppApi(Resource): request_id: Union[int, str] | None = args.id mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True)) - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: # Get MCP server and app mcp_server, app = self._get_mcp_server_and_app(server_code, session) self._validate_server_status(mcp_server) @@ -189,7 +189,7 @@ class MCPAppApi(Resource): def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None: """Get end user - manages its own database session""" - with Session(db.engine, expire_on_commit=False) as session, session.begin(): + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: return ( session.query(EndUser) .where(EndUser.tenant_id == tenant_id) @@ -229,9 +229,7 @@ class MCPAppApi(Resource): if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest): client_info = mcp_request.root.params.clientInfo client_name = f"{client_info.name}@{client_info.version}" - # Commit the session before creating end user to avoid transaction conflicts - session.commit() - with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin(): + with sessionmaker(db.engine, expire_on_commit=False).begin() as create_session: end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session) return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index e76649495a0..d5baa5fb7d1 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -2,7 +2,7 @@ from typing import Literal from flask import request from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound from controllers.common.schema import register_schema_models @@ -99,7 +99,7 @@ class ConversationListApi(WebApiResource): query = ConversationListQuery.model_validate(raw_args) try: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: pagination = WebConversationService.pagination_by_last_id( session=session, app_model=app_model, diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 91d206f7270..d69571cc9cf 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -4,7 +4,7 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console.auth.error import ( @@ -81,7 +81,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) token = None if account is None: @@ -180,18 +180,17 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: - self._update_existing_account(account, password_hashed, salt, session) + self._update_existing_account(account, password_hashed, salt) else: raise AuthenticationFailedError() return {"result": "success"} - def _update_existing_account(self, account: Account, password_hashed, salt, session): + def _update_existing_account(self, account: Account, password_hashed, salt): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() - session.commit() diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 152137f39c8..654951a1aa5 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -6,7 +6,7 @@ from typing import Concatenate, ParamSpec, TypeVar from flask import request from flask_restx import Resource from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from constants import HEADER_NAME_APP_CODE @@ -49,7 +49,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None): decoded = PassportService().verify(tk) app_code = decoded.get("app_code") app_id = decoded.get("app_id") - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: app_model = session.scalar(select(App).where(App.id == app_id)) site = session.scalar(select(Site).where(Site.code == app_code)) if not app_model: diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 19057726c34..04ad143103b 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -37,7 +37,7 @@ class TestForgotPasswordSendEmailApi: @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") - @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.sessionmaker") def test_should_normalize_email_before_sending( self, mock_session_cls, @@ -51,7 +51,7 @@ class TestForgotPasswordSendEmailApi: mock_get_account.return_value = mock_account mock_send_mail.return_value = "token-123" mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session + mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): with app.test_request_context( @@ -153,7 +153,7 @@ class TestForgotPasswordResetApi: @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.sessionmaker") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") def test_should_fetch_account_with_fallback( @@ -169,7 +169,7 @@ class TestForgotPasswordResetApi: mock_account = MagicMock() mock_get_account.return_value = mock_account mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session + mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): with app.test_request_context( @@ -190,7 +190,7 @@ class TestForgotPasswordResetApi: @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") - @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.sessionmaker") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -208,7 +208,7 @@ class TestForgotPasswordResetApi: account = MagicMock() mock_get_account.return_value = account mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session + mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): with app.test_request_context( @@ -231,4 +231,3 @@ class TestForgotPasswordResetApi: assert account.password == expected_password expected_salt = base64.b64encode(b"0123456789abcdef").decode() assert account.password_salt == expected_salt - mock_session.commit.assert_called_once()