mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 17:09:23 +08:00
refactor: use sessionmaker().begin() in web and mcp controllers (#34281)
This commit is contained in:
@@ -4,7 +4,7 @@ from flask import Response
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from graphon.variables.input_entities import VariableEntity
|
from graphon.variables.input_entities import VariableEntity
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.mcp import mcp_ns
|
from controllers.mcp import mcp_ns
|
||||||
@@ -67,7 +67,7 @@ class MCPAppApi(Resource):
|
|||||||
request_id: Union[int, str] | None = args.id
|
request_id: Union[int, str] | None = args.id
|
||||||
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
|
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
|
||||||
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
# Get MCP server and app
|
# Get MCP server and app
|
||||||
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
|
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
|
||||||
self._validate_server_status(mcp_server)
|
self._validate_server_status(mcp_server)
|
||||||
@@ -189,7 +189,7 @@ class MCPAppApi(Resource):
|
|||||||
|
|
||||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
|
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
|
||||||
"""Get end user - manages its own database session"""
|
"""Get end user - manages its own database session"""
|
||||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
return (
|
return (
|
||||||
session.query(EndUser)
|
session.query(EndUser)
|
||||||
.where(EndUser.tenant_id == tenant_id)
|
.where(EndUser.tenant_id == tenant_id)
|
||||||
@@ -229,9 +229,7 @@ class MCPAppApi(Resource):
|
|||||||
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
||||||
client_info = mcp_request.root.params.clientInfo
|
client_info = mcp_request.root.params.clientInfo
|
||||||
client_name = f"{client_info.name}@{client_info.version}"
|
client_name = f"{client_info.name}@{client_info.version}"
|
||||||
# Commit the session before creating end user to avoid transaction conflicts
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as create_session:
|
||||||
session.commit()
|
|
||||||
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
|
|
||||||
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
||||||
|
|
||||||
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Literal
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
@@ -99,7 +99,7 @@ class ConversationListApi(WebApiResource):
|
|||||||
query = ConversationListQuery.model_validate(raw_args)
|
query = ConversationListQuery.model_validate(raw_args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
pagination = WebConversationService.pagination_by_last_id(
|
pagination = WebConversationService.pagination_by_last_id(
|
||||||
session=session,
|
session=session,
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import secrets
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.auth.error import (
|
from controllers.console.auth.error import (
|
||||||
@@ -81,7 +81,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
||||||
token = None
|
token = None
|
||||||
if account is None:
|
if account is None:
|
||||||
@@ -180,18 +180,17 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
|
|
||||||
email = reset_data.get("email", "")
|
email = reset_data.get("email", "")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||||
|
|
||||||
if account:
|
if account:
|
||||||
self._update_existing_account(account, password_hashed, salt, session)
|
self._update_existing_account(account, password_hashed, salt)
|
||||||
else:
|
else:
|
||||||
raise AuthenticationFailedError()
|
raise AuthenticationFailedError()
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
def _update_existing_account(self, account: Account, password_hashed, salt, session):
|
def _update_existing_account(self, account: Account, password_hashed, salt):
|
||||||
# Update existing account credentials
|
# Update existing account credentials
|
||||||
account.password = base64.b64encode(password_hashed).decode()
|
account.password = base64.b64encode(password_hashed).decode()
|
||||||
account.password_salt = base64.b64encode(salt).decode()
|
account.password_salt = base64.b64encode(salt).decode()
|
||||||
session.commit()
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Concatenate, ParamSpec, TypeVar
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||||
|
|
||||||
from constants import HEADER_NAME_APP_CODE
|
from constants import HEADER_NAME_APP_CODE
|
||||||
@@ -49,7 +49,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
|
|||||||
decoded = PassportService().verify(tk)
|
decoded = PassportService().verify(tk)
|
||||||
app_code = decoded.get("app_code")
|
app_code = decoded.get("app_code")
|
||||||
app_id = decoded.get("app_id")
|
app_id = decoded.get("app_id")
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
app_model = session.scalar(select(App).where(App.id == app_id))
|
app_model = session.scalar(select(App).where(App.id == app_id))
|
||||||
site = session.scalar(select(Site).where(Site.code == app_code))
|
site = session.scalar(select(Site).where(Site.code == app_code))
|
||||||
if not app_model:
|
if not app_model:
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class TestForgotPasswordSendEmailApi:
|
|||||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||||
@patch("controllers.web.forgot_password.Session")
|
@patch("controllers.web.forgot_password.sessionmaker")
|
||||||
def test_should_normalize_email_before_sending(
|
def test_should_normalize_email_before_sending(
|
||||||
self,
|
self,
|
||||||
mock_session_cls,
|
mock_session_cls,
|
||||||
@@ -51,7 +51,7 @@ class TestForgotPasswordSendEmailApi:
|
|||||||
mock_get_account.return_value = mock_account
|
mock_get_account.return_value = mock_account
|
||||||
mock_send_mail.return_value = "token-123"
|
mock_send_mail.return_value = "token-123"
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
@@ -153,7 +153,7 @@ class TestForgotPasswordResetApi:
|
|||||||
|
|
||||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||||
@patch("controllers.web.forgot_password.Session")
|
@patch("controllers.web.forgot_password.sessionmaker")
|
||||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||||
def test_should_fetch_account_with_fallback(
|
def test_should_fetch_account_with_fallback(
|
||||||
@@ -169,7 +169,7 @@ class TestForgotPasswordResetApi:
|
|||||||
mock_account = MagicMock()
|
mock_account = MagicMock()
|
||||||
mock_get_account.return_value = mock_account
|
mock_get_account.return_value = mock_account
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
@@ -190,7 +190,7 @@ class TestForgotPasswordResetApi:
|
|||||||
|
|
||||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||||
@patch("controllers.web.forgot_password.Session")
|
@patch("controllers.web.forgot_password.sessionmaker")
|
||||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||||
@@ -208,7 +208,7 @@ class TestForgotPasswordResetApi:
|
|||||||
account = MagicMock()
|
account = MagicMock()
|
||||||
mock_get_account.return_value = account
|
mock_get_account.return_value = account
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
@@ -231,4 +231,3 @@ class TestForgotPasswordResetApi:
|
|||||||
assert account.password == expected_password
|
assert account.password == expected_password
|
||||||
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||||
assert account.password_salt == expected_salt
|
assert account.password_salt == expected_salt
|
||||||
mock_session.commit.assert_called_once()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user