refactor: use sessionmaker().begin() in web and mcp controllers (#34281)

This commit is contained in:
Desel72
2026-03-31 17:26:37 +03:00
committed by GitHub
parent cf50d7c7b5
commit 2c8b47ce44
5 changed files with 19 additions and 23 deletions

View File

@@ -4,7 +4,7 @@ from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.schema import register_schema_model
from controllers.mcp import mcp_ns
@@ -67,7 +67,7 @@ class MCPAppApi(Resource):
request_id: Union[int, str] | None = args.id
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get MCP server and app
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
self._validate_server_status(mcp_server)
@@ -189,7 +189,7 @@ class MCPAppApi(Resource):
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
@@ -229,9 +229,7 @@ class MCPAppApi(Resource):
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
# Commit the session before creating end user to avoid transaction conflicts
session.commit()
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
with sessionmaker(db.engine, expire_on_commit=False).begin() as create_session:
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)

View File

@@ -2,7 +2,7 @@ from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
@@ -99,7 +99,7 @@ class ConversationListApi(WebApiResource):
query = ConversationListQuery.model_validate(raw_args)
try:
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,

View File

@@ -4,7 +4,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
@@ -81,7 +81,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
@@ -180,18 +180,17 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
self._update_existing_account(account, password_hashed, salt)
else:
raise AuthenticationFailedError()
return {"result": "success"}
def _update_existing_account(self, account: Account, password_hashed, salt, session):
def _update_existing_account(self, account: Account, password_hashed, salt):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()

View File

@@ -6,7 +6,7 @@ from typing import Concatenate, ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from constants import HEADER_NAME_APP_CODE
@@ -49,7 +49,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_model = session.scalar(select(App).where(App.id == app_id))
site = session.scalar(select(Site).where(Site.code == app_code))
if not app_model:

View File

@@ -37,7 +37,7 @@ class TestForgotPasswordSendEmailApi:
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.sessionmaker")
def test_should_normalize_email_before_sending(
self,
mock_session_cls,
@@ -51,7 +51,7 @@ class TestForgotPasswordSendEmailApi:
mock_get_account.return_value = mock_account
mock_send_mail.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
@@ -153,7 +153,7 @@ class TestForgotPasswordResetApi:
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.sessionmaker")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_should_fetch_account_with_fallback(
@@ -169,7 +169,7 @@ class TestForgotPasswordResetApi:
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
@@ -190,7 +190,7 @@ class TestForgotPasswordResetApi:
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.sessionmaker")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@@ -208,7 +208,7 @@ class TestForgotPasswordResetApi:
account = MagicMock()
mock_get_account.return_value = account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
@@ -231,4 +231,3 @@ class TestForgotPasswordResetApi:
assert account.password == expected_password
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
assert account.password_salt == expected_salt
mock_session.commit.assert_called_once()