refactor: use sessionmaker().begin() in console workspace and misc co… (#34284)

This commit is contained in:
Desel72
2026-03-31 17:28:05 +03:00
committed by GitHub
parent 2c8b47ce44
commit dbdbb098d5
8 changed files with 29 additions and 30 deletions

View File

@@ -2,7 +2,7 @@ import flask_restx
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus from flask_restx._http import HTTPStatus
from sqlalchemy import delete, func, select from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
@@ -34,7 +34,7 @@ api_key_list_model = console_ns.model(
def _get_resource(resource_id, tenant_id, resource_model): def _get_resource(resource_id, tenant_id, resource_model):
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
resource = session.execute( resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none() ).scalar_one_or_none()

View File

@@ -2,7 +2,7 @@ from typing import Any
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter, model_validator from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@@ -74,7 +74,7 @@ class ConversationListApi(InstalledAppResource):
try: try:
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance") raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
pagination = WebConversationService.pagination_by_last_id( pagination = WebConversationService.pagination_by_last_id(
session=session, session=session,
app_model=app_model, app_model=app_model,

View File

@@ -2,7 +2,7 @@ from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar from typing import ParamSpec, TypeVar
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
@@ -24,7 +24,7 @@ def plugin_permission_required(
user = current_user user = current_user
tenant_id = current_tenant_id tenant_id = current_tenant_id
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
permission = ( permission = (
session.query(TenantPluginPermission) session.query(TenantPluginPermission)
.where( .where(

View File

@@ -8,7 +8,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
@@ -562,7 +562,7 @@ class ChangeEmailSendEmailApi(Resource):
user_email = current_user.email user_email = current_user.email
else: else:
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None: if account is None:
raise AccountNotFound() raise AccountNotFound()

View File

@@ -7,7 +7,7 @@ from flask import make_response, redirect, request, send_file
from flask_restx import Resource from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
@@ -1019,7 +1019,7 @@ class ToolProviderMCPApi(Resource):
# Step 1: Get provider data for URL validation (short-lived session, no network I/O) # Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_data = None validation_data = None
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
validation_data = service.get_provider_for_url_validation( validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=payload.provider_id tenant_id=current_tenant_id, provider_id=payload.provider_id
@@ -1034,7 +1034,7 @@ class ToolProviderMCPApi(Resource):
) )
# Step 3: Perform database update in a transaction # Step 3: Perform database update in a transaction
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.update_provider( service.update_provider(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
@@ -1061,7 +1061,7 @@ class ToolProviderMCPApi(Resource):
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {}) payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id) service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id)
@@ -1079,7 +1079,7 @@ class ToolMCPAuthApi(Resource):
provider_id = payload.provider_id provider_id = payload.provider_id
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider: if not db_provider:
@@ -1100,7 +1100,7 @@ class ToolMCPAuthApi(Resource):
sse_read_timeout=provider_entity.sse_read_timeout, sse_read_timeout=provider_entity.sse_read_timeout,
): ):
# Update credentials in new transaction # Update credentials in new transaction
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.update_provider_credentials( service.update_provider_credentials(
provider_id=provider_id, provider_id=provider_id,
@@ -1118,17 +1118,17 @@ class ToolMCPAuthApi(Resource):
resource_metadata_url=e.resource_metadata_url, resource_metadata_url=e.resource_metadata_url,
scope_hint=e.scope_hint, scope_hint=e.scope_hint,
) )
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result) response = service.execute_auth_actions(auth_result)
return response return response
except MCPRefreshTokenError as e: except MCPRefreshTokenError as e:
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except (MCPError, ValueError) as e: except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e raise ValueError(f"Failed to connect to MCP server: {e}") from e
@@ -1141,7 +1141,7 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, provider_id): def get(self, provider_id):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@@ -1155,7 +1155,7 @@ class ToolMCPListAllApi(Resource):
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
# Skip sensitive data decryption for list view to improve performance # Skip sensitive data decryption for list view to improve performance
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False) tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
@@ -1170,7 +1170,7 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, provider_id): def get(self, provider_id):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
tools = service.list_provider_tools( tools = service.list_provider_tools(
tenant_id=tenant_id, tenant_id=tenant_id,
@@ -1188,7 +1188,7 @@ class ToolMCPCallbackApi(Resource):
authorization_code = query.code authorization_code = query.code
# Create service instance for handle_callback # Create service instance for handle_callback
with Session(db.engine) as session, session.begin(): with sessionmaker(db.engine).begin() as session:
mcp_service = MCPToolManageService(session=session) mcp_service = MCPToolManageService(session=session)
# handle_callback now returns state data and tokens # handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code) state_data, tokens = handle_callback(state_key, authorization_code)

View File

@@ -5,7 +5,7 @@ from flask import make_response, redirect, request
from flask_restx import Resource from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config from configs import dify_config
@@ -375,7 +375,7 @@ class TriggerSubscriptionDeleteApi(Resource):
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
try: try:
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
# Delete trigger provider subscription # Delete trigger provider subscription
TriggerProviderService.delete_trigger_provider( TriggerProviderService.delete_trigger_provider(
session=session, session=session,
@@ -388,7 +388,6 @@ class TriggerSubscriptionDeleteApi(Resource):
tenant_id=user.current_tenant_id, tenant_id=user.current_tenant_id,
subscription_id=subscription_id, subscription_id=subscription_id,
) )
session.commit()
return {"result": "success"} return {"result": "success"}
except ValueError as e: except ValueError as e:
raise BadRequest(str(e)) raise BadRequest(str(e))

View File

@@ -69,7 +69,7 @@ def client(flask_app_with_containers):
return_value=(MagicMock(id="u1"), "t1"), return_value=(MagicMock(id="u1"), "t1"),
autospec=True, autospec=True,
) )
@patch("controllers.console.workspace.tool_providers.Session", autospec=True) @patch("controllers.console.workspace.tool_providers.sessionmaker", autospec=True)
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True) @patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True)
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant") @pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client): def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client):
@@ -88,7 +88,7 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
create_result.id = "provider-1" create_result.id = "provider-1"
svc.create_provider.return_value = create_result svc.create_provider.return_value = create_result
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
mock_session.return_value.__enter__.return_value = MagicMock() mock_session.return_value.begin.return_value.__enter__.return_value = MagicMock()
# Patch MCPToolManageService constructed inside controller # Patch MCPToolManageService constructed inside controller
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc, autospec=True): with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc, autospec=True):
payload = { payload = {

View File

@@ -306,14 +306,14 @@ class TestTriggerSubscriptionCrud:
app.test_request_context("/"), app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db, patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls, patch("controllers.console.workspace.trigger_providers.sessionmaker") as mock_session_cls,
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"), patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
patch( patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription" "controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription"
), ),
): ):
mock_db.engine = MagicMock() mock_db.engine = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
result = method(api, "sub1") result = method(api, "sub1")
@@ -327,14 +327,14 @@ class TestTriggerSubscriptionCrud:
app.test_request_context("/"), app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db, patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as session_cls, patch("controllers.console.workspace.trigger_providers.sessionmaker") as session_cls,
patch( patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider", "controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider",
side_effect=ValueError("bad"), side_effect=ValueError("bad"),
), ),
): ):
mock_db.engine = MagicMock() mock_db.engine = MagicMock()
session_cls.return_value.__enter__.return_value = MagicMock() session_cls.return_value.begin.return_value.__enter__.return_value = MagicMock()
with pytest.raises(BadRequest): with pytest.raises(BadRequest):
method(api, "sub1") method(api, "sub1")