refactor: use sessionmaker().begin() in web and mcp controllers (#34281)

This commit is contained in:
Desel72
2026-03-31 17:26:37 +03:00
committed by GitHub
parent cf50d7c7b5
commit 2c8b47ce44
5 changed files with 19 additions and 23 deletions

View File

@@ -4,7 +4,7 @@ from flask import Response
from flask_restx import Resource from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity from graphon.variables.input_entities import VariableEntity
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from controllers.mcp import mcp_ns from controllers.mcp import mcp_ns
@@ -67,7 +67,7 @@ class MCPAppApi(Resource):
request_id: Union[int, str] | None = args.id request_id: Union[int, str] | None = args.id
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True)) mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
with Session(db.engine, expire_on_commit=False) as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get MCP server and app # Get MCP server and app
mcp_server, app = self._get_mcp_server_and_app(server_code, session) mcp_server, app = self._get_mcp_server_and_app(server_code, session)
self._validate_server_status(mcp_server) self._validate_server_status(mcp_server)
@@ -189,7 +189,7 @@ class MCPAppApi(Resource):
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None: def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session""" """Get end user - manages its own database session"""
with Session(db.engine, expire_on_commit=False) as session, session.begin(): with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return ( return (
session.query(EndUser) session.query(EndUser)
.where(EndUser.tenant_id == tenant_id) .where(EndUser.tenant_id == tenant_id)
@@ -229,9 +229,7 @@ class MCPAppApi(Resource):
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest): if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo client_info = mcp_request.root.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}" client_name = f"{client_info.name}@{client_info.version}"
# Commit the session before creating end user to avoid transaction conflicts with sessionmaker(db.engine, expire_on_commit=False).begin() as create_session:
session.commit()
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session) end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id) return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)

View File

@@ -2,7 +2,7 @@ from typing import Literal
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@@ -99,7 +99,7 @@ class ConversationListApi(WebApiResource):
query = ConversationListQuery.model_validate(raw_args) query = ConversationListQuery.model_validate(raw_args)
try: try:
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
pagination = WebConversationService.pagination_by_last_id( pagination = WebConversationService.pagination_by_last_id(
session=session, session=session,
app_model=app_model, app_model=app_model,

View File

@@ -4,7 +4,7 @@ import secrets
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console.auth.error import ( from controllers.console.auth.error import (
@@ -81,7 +81,7 @@ class ForgotPasswordSendEmailApi(Resource):
else: else:
language = "en-US" language = "en-US"
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None token = None
if account is None: if account is None:
@@ -180,18 +180,17 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "") email = reset_data.get("email", "")
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session) account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account: if account:
self._update_existing_account(account, password_hashed, salt, session) self._update_existing_account(account, password_hashed, salt)
else: else:
raise AuthenticationFailedError() raise AuthenticationFailedError()
return {"result": "success"} return {"result": "success"}
def _update_existing_account(self, account: Account, password_hashed, salt, session): def _update_existing_account(self, account: Account, password_hashed, salt):
# Update existing account credentials # Update existing account credentials
account.password = base64.b64encode(password_hashed).decode() account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode() account.password_salt = base64.b64encode(salt).decode()
session.commit()

View File

@@ -6,7 +6,7 @@ from typing import Concatenate, ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from constants import HEADER_NAME_APP_CODE from constants import HEADER_NAME_APP_CODE
@@ -49,7 +49,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
decoded = PassportService().verify(tk) decoded = PassportService().verify(tk)
app_code = decoded.get("app_code") app_code = decoded.get("app_code")
app_id = decoded.get("app_id") app_id = decoded.get("app_id")
with Session(db.engine, expire_on_commit=False) as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_model = session.scalar(select(App).where(App.id == app_id)) app_model = session.scalar(select(App).where(App.id == app_id))
site = session.scalar(select(Site).where(Site.code == app_code)) site = session.scalar(select(Site).where(Site.code == app_code))
if not app_model: if not app_model:

View File

@@ -37,7 +37,7 @@ class TestForgotPasswordSendEmailApi:
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
@patch("controllers.web.forgot_password.Session") @patch("controllers.web.forgot_password.sessionmaker")
def test_should_normalize_email_before_sending( def test_should_normalize_email_before_sending(
self, self,
mock_session_cls, mock_session_cls,
@@ -51,7 +51,7 @@ class TestForgotPasswordSendEmailApi:
mock_get_account.return_value = mock_account mock_get_account.return_value = mock_account
mock_send_mail.return_value = "token-123" mock_send_mail.return_value = "token-123"
mock_session = MagicMock() mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context( with app.test_request_context(
@@ -153,7 +153,7 @@ class TestForgotPasswordResetApi:
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.Session") @patch("controllers.web.forgot_password.sessionmaker")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_should_fetch_account_with_fallback( def test_should_fetch_account_with_fallback(
@@ -169,7 +169,7 @@ class TestForgotPasswordResetApi:
mock_account = MagicMock() mock_account = MagicMock()
mock_get_account.return_value = mock_account mock_get_account.return_value = mock_account
mock_session = MagicMock() mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context( with app.test_request_context(
@@ -190,7 +190,7 @@ class TestForgotPasswordResetApi:
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.Session") @patch("controllers.web.forgot_password.sessionmaker")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@@ -208,7 +208,7 @@ class TestForgotPasswordResetApi:
account = MagicMock() account = MagicMock()
mock_get_account.return_value = account mock_get_account.return_value = account
mock_session = MagicMock() mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context( with app.test_request_context(
@@ -231,4 +231,3 @@ class TestForgotPasswordResetApi:
assert account.password == expected_password assert account.password == expected_password
expected_salt = base64.b64encode(b"0123456789abcdef").decode() expected_salt = base64.b64encode(b"0123456789abcdef").decode()
assert account.password_salt == expected_salt assert account.password_salt == expected_salt
mock_session.commit.assert_called_once()