refactor: use sessionmaker().begin() in console datasets controllers (#34283)

This commit is contained in:
Desel72
2026-03-31 16:09:18 +03:00
committed by GitHub
parent b818cc0766
commit d9a0665b2c
7 changed files with 34 additions and 42 deletions

View File

@@ -6,7 +6,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.schema import get_or_create_model, register_schema_model
@@ -159,7 +159,7 @@ class DataSourceApi(Resource):
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
binding_id = str(binding_id)
with Session(db.engine) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
@@ -211,7 +211,7 @@ class DataSourceNotionListApi(Resource):
if not credential:
raise NotFound("Credential not found.")
exist_page_ids = []
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
# import notion in the exist dataset
if query.dataset_id:
dataset = DatasetService.get_dataset(query.dataset_id)

View File

@@ -3,7 +3,7 @@ import logging
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
@@ -85,7 +85,7 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required
@enterprise_license_required
def post(self, template_id: str):
with Session(db.engine) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
template = (
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)

View File

@@ -1,6 +1,6 @@
from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
import services
@@ -54,7 +54,7 @@ class CreateRagPipelineDatasetApi(Resource):
yaml_content=payload.yaml_content,
)
try:
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_tenant_id,

View File

@@ -5,7 +5,7 @@ from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
@@ -96,7 +96,7 @@ class RagPipelineVariableCollectionApi(Resource):
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with Session(bind=db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@@ -143,7 +143,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
@marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@@ -289,7 +289,7 @@ class RagPipelineVariableResetApi(Resource):
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
with Session(bind=db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)

View File

@@ -1,7 +1,7 @@
from flask import request
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
@@ -68,7 +68,7 @@ class RagPipelineImportApi(Resource):
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Create service with session
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
# Import app
account = current_user
@@ -80,7 +80,6 @@ class RagPipelineImportApi(Resource):
pipeline_id=payload.pipeline_id,
dataset_name=payload.name,
)
session.commit()
# Return appropriate status code based on result
status = result.status
@@ -102,12 +101,11 @@ class RagPipelineImportConfirmApi(Resource):
current_user, _ = current_account_with_tenant()
# Create service with session
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@@ -124,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
@@ -142,7 +140,7 @@ class RagPipelineExportApi(Resource):
# Add include_secret params
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=query.include_secret == "true"

View File

@@ -6,7 +6,7 @@ from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
@@ -608,7 +608,7 @@ class PublishedRagPipelineApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
pipeline = session.merge(pipeline)
workflow = rag_pipeline_service.publish_workflow(
session=session,
@@ -620,8 +620,6 @@ class PublishedRagPipelineApi(Resource):
session.add(pipeline)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
"created_at": workflow_created_at,
@@ -695,7 +693,7 @@ class PublishedAllRagPipelineApi(Resource):
raise Forbidden()
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
session=session,
pipeline=pipeline,
@@ -767,7 +765,7 @@ class RagPipelineByIdApi(Resource):
rag_pipeline_service = RagPipelineService()
# Create a session and manage the transaction
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow = rag_pipeline_service.update_workflow(
session=session,
workflow_id=workflow_id,
@@ -779,9 +777,6 @@ class RagPipelineByIdApi(Resource):
if not workflow:
raise NotFound("Workflow not found")
# Commit the transaction in the controller
session.commit()
return workflow
@setup_required
@@ -798,14 +793,13 @@ class RagPipelineByIdApi(Resource):
workflow_service = WorkflowService()
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
try:
workflow_service.delete_workflow(
session=session,
workflow_id=workflow_id,
tenant_id=pipeline.tenant_id,
)
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:

View File

@@ -102,12 +102,12 @@ class TestDataSourceApi:
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class,
patch("controllers.console.datasets.data_source.db.session.add"),
patch("controllers.console.datasets.data_source.db.session.commit"),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "enable")
@@ -123,12 +123,12 @@ class TestDataSourceApi:
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class,
patch("controllers.console.datasets.data_source.db.session.add"),
patch("controllers.console.datasets.data_source.db.session.commit"),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "disable")
@@ -142,10 +142,10 @@ class TestDataSourceApi:
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = None
with pytest.raises(NotFound):
@@ -159,10 +159,10 @@ class TestDataSourceApi:
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
@@ -176,10 +176,10 @@ class TestDataSourceApi:
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
@@ -282,7 +282,7 @@ class TestDataSourceNotionListApi:
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class,
patch(
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
return_value=MagicMock(
@@ -292,7 +292,7 @@ class TestDataSourceNotionListApi:
),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = [document]
response, status = method(api)
@@ -315,7 +315,7 @@ class TestDataSourceNotionListApi:
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.data_source.Session"),
patch("controllers.console.datasets.data_source.sessionmaker"),
):
with pytest.raises(ValueError):
method(api)