diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 50f34d5a8ad..5b3668aebbd 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -156,27 +156,27 @@ class RagPipelineService: :param template_id: template id :param template_info: template info """ - customized_template: PipelineCustomizedTemplate | None = ( - db.session.query(PipelineCustomizedTemplate) + customized_template: PipelineCustomizedTemplate | None = db.session.scalar( + select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, ) - .first() + .limit(1) ) if not customized_template: raise ValueError("Customized pipeline template not found.") # check template name is exist template_name = template_info.name if template_name: - template = ( - db.session.query(PipelineCustomizedTemplate) + template = db.session.scalar( + select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.name == template_name, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, PipelineCustomizedTemplate.id != template_id, ) - .first() + .limit(1) ) if template: raise ValueError("Template name is already exists") @@ -192,13 +192,13 @@ class RagPipelineService: """ Delete customized pipeline template. """ - customized_template: PipelineCustomizedTemplate | None = ( - db.session.query(PipelineCustomizedTemplate) + customized_template: PipelineCustomizedTemplate | None = db.session.scalar( + select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, ) - .first() + .limit(1) ) if not customized_template: raise ValueError("Customized pipeline template not found.") @@ -210,14 +210,14 @@ class RagPipelineService: Get draft workflow """ # fetch draft workflow by rag pipeline - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", ) - .first() + .limit(1) ) # return draft workflow @@ -232,28 +232,28 @@ class RagPipelineService: return None # fetch published workflow by workflow_id - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == pipeline.workflow_id, ) - .first() + .limit(1) ) return workflow def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: """Fetch a published workflow snapshot by ID for restore operations.""" - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id, ) - .first() + .limit(1) ) if workflow and workflow.version == Workflow.VERSION_DRAFT: raise IsDraftWorkflowError("source workflow must be published") @@ -974,7 +974,7 @@ class RagPipelineService: if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) if document_id: - document = db.session.query(Document).where(Document.id == document_id.value).first() + document = db.session.get(Document, document_id.value) if document: document.indexing_status = IndexingStatus.ERROR document.error = error @@ -1178,12 +1178,12 @@ class RagPipelineService: """ Publish customized pipeline template """ - pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first() + pipeline = db.session.get(Pipeline, pipeline_id) if not pipeline: raise ValueError("Pipeline not found") if not pipeline.workflow_id: raise ValueError("Pipeline workflow not found") - workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() + workflow = db.session.get(Workflow, pipeline.workflow_id) if not workflow: raise ValueError("Workflow not found") with Session(db.engine) as session: @@ -1194,21 +1194,21 @@ class RagPipelineService: # check template name is exist template_name = args.get("name") if template_name: - template = ( - db.session.query(PipelineCustomizedTemplate) + template = db.session.scalar( + select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.name == template_name, PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id, ) - .first() + .limit(1) ) if template: raise ValueError("Template name is already exists") - max_position = ( - db.session.query(func.max(PipelineCustomizedTemplate.position)) - .where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) - .scalar() + max_position = db.session.scalar( + select(func.max(PipelineCustomizedTemplate.position)).where( + PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id + ) ) from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService @@ -1239,13 +1239,14 @@ class RagPipelineService: def is_workflow_exist(self, pipeline: Pipeline) -> bool: return ( - db.session.query(Workflow) - .where( - Workflow.tenant_id == pipeline.tenant_id, - Workflow.app_id == pipeline.id, - Workflow.version == Workflow.VERSION_DRAFT, + db.session.scalar( + select(func.count(Workflow.id)).where( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == Workflow.VERSION_DRAFT, + ) ) - .count() + or 0 ) > 0 def get_node_last_run( @@ -1353,11 +1354,11 @@ class RagPipelineService: def get_recommended_plugins(self, type: str) -> dict: # Query active recommended plugins - query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) + stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) if type and type != "all": - query = query.where(PipelineRecommendedPlugin.type == type) + stmt = stmt.where(PipelineRecommendedPlugin.type == type) - pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all() + pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all() if not pipeline_recommended_plugins: return { @@ -1396,14 +1397,12 @@ class RagPipelineService: """ Retry error document """ - document_pipeline_execution_log = ( - db.session.query(DocumentPipelineExecutionLog) - .where(DocumentPipelineExecutionLog.document_id == document.id) - .first() + document_pipeline_execution_log = db.session.scalar( + select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1) ) if not document_pipeline_execution_log: raise ValueError("Document pipeline execution log not found") - pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first() + pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id) if not pipeline: raise ValueError("Pipeline not found") # convert to app config @@ -1432,23 +1431,23 @@ class RagPipelineService: """ Get datasource plugins """ - dataset: Dataset | None = ( - db.session.query(Dataset) + dataset: Dataset | None = db.session.scalar( + select(Dataset) .where( Dataset.id == dataset_id, Dataset.tenant_id == tenant_id, ) - .first() + .limit(1) ) if not dataset: raise ValueError("Dataset not found") - pipeline: Pipeline | None = ( - db.session.query(Pipeline) + pipeline: Pipeline | None = db.session.scalar( + select(Pipeline) .where( Pipeline.id == dataset.pipeline_id, Pipeline.tenant_id == tenant_id, ) - .first() + .limit(1) ) if not pipeline: raise ValueError("Pipeline not found") @@ -1530,23 +1529,23 @@ class RagPipelineService: """ Get pipeline """ - dataset: Dataset | None = ( - db.session.query(Dataset) + dataset: Dataset | None = db.session.scalar( + select(Dataset) .where( Dataset.id == dataset_id, Dataset.tenant_id == tenant_id, ) - .first() + .limit(1) ) if not dataset: raise ValueError("Dataset not found") - pipeline: Pipeline | None = ( - db.session.query(Pipeline) + pipeline: Pipeline | None = db.session.scalar( + select(Pipeline) .where( Pipeline.id == dataset.pipeline_id, Pipeline.tenant_id == tenant_id, ) - .first() + .limit(1) ) if not pipeline: raise ValueError("Pipeline not found") diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py index cb3c2d742d9..f270ee0fde5 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py @@ -117,9 +117,7 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None: - first_query = mocker.Mock() - first_query.where.return_value.first.return_value = None - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=first_query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) with pytest.raises(ValueError, match="Dataset not found"): rag_pipeline_service.get_pipeline("tenant-1", "dataset-1") @@ -131,12 +129,8 @@ def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service def test_update_customized_pipeline_template_success(mocker) -> None: template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) - # First query finds the template, second query (duplicate check) returns None - query_mock_1 = mocker.Mock() - query_mock_1.where.return_value.first.return_value = template - query_mock_2 = mocker.Mock() - query_mock_2.where.return_value.first.return_value = None - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", side_effect=[query_mock_1, query_mock_2]) + # First scalar finds the template, second scalar (duplicate check) returns None + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None]) mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) @@ -152,9 +146,7 @@ def test_update_customized_pipeline_template_success(mocker) -> None: def test_update_customized_pipeline_template_not_found(mocker) -> None: - query_mock = mocker.Mock() - query_mock.where.return_value.first.return_value = None - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i")) @@ -166,9 +158,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None: template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) duplicate = SimpleNamespace(name="dup") - query_mock = mocker.Mock() - query_mock.where.return_value.first.side_effect = [template, duplicate] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate]) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i")) @@ -181,9 +171,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None: def test_delete_customized_pipeline_template_success(mocker) -> None: template = SimpleNamespace(id="tpl-1") - query_mock = mocker.Mock() - query_mock.where.return_value.first.return_value = template - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template) delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete") commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") @@ -196,9 +184,7 @@ def test_delete_customized_pipeline_template_success(mocker) -> None: def test_delete_customized_pipeline_template_not_found(mocker) -> None: - query_mock = mocker.Mock() - query_mock.where.return_value.first.return_value = None - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) with pytest.raises(ValueError, match="Customized pipeline template not found"): @@ -397,18 +383,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) - def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None: - query_mock = mocker.Mock() - query_mock.where.return_value.count.return_value = 1 - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1) pipeline = SimpleNamespace(tenant_id="t1", id="p1") assert rag_pipeline_service.is_workflow_exist(pipeline) is True def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None: - query_mock = mocker.Mock() - query_mock.where.return_value.count.return_value = 0 - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0) pipeline = SimpleNamespace(tenant_id="t1", id="p1") assert rag_pipeline_service.is_workflow_exist(pipeline) is False @@ -738,8 +720,7 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None: - from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate - from models.workflow import Workflow + from models.dataset import Pipeline # 1. Setup mocks pipeline = mocker.Mock(spec=Pipeline) @@ -754,36 +735,15 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi # Mock db itself to avoid app context errors mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") - # Improved mocking for session.query - def mock_query_side_effect(model): - m = mocker.Mock() - if model == Pipeline: - m.where.return_value.first.return_value = pipeline - elif model == Workflow: - m.where.return_value.first.return_value = workflow - elif model == PipelineCustomizedTemplate: - m.where.return_value.first.return_value = None - elif model == Dataset: - m.where.return_value.first.return_value = mocker.Mock() - else: - # For func.max cases - m.where.return_value.scalar.return_value = 5 - m.where.return_value.first.return_value = mocker.Mock() - return m - - mock_db.session.query.side_effect = mock_query_side_effect + # Mock get() for Pipeline and Workflow PK lookups + mock_db.session.get.side_effect = [pipeline, workflow] + # Mock scalar() for template name check (None) and max position (5) + mock_db.session.scalar.side_effect = [None, 5] # Mock retrieve_dataset dataset = mocker.Mock() pipeline.retrieve_dataset.return_value = dataset - # Mock max position - mocker.patch("services.rag_pipeline.rag_pipeline.func.max", return_value=1) - mocker.patch( - "services.rag_pipeline.rag_pipeline.db.session.query.return_value.where.return_value.scalar", - return_value=5, - ) - # Mock RagPipelineDslService mock_dsl_service = mocker.Mock() mock_dsl_service.export_rag_pipeline_dsl.return_value = {"dsl": "content"} @@ -839,9 +799,7 @@ def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None: workflow.rag_pipeline_variables = [] # Mock queries - mock_query = mocker.Mock() - mock_query.where.return_value.first.side_effect = [dataset, pipeline] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) @@ -881,11 +839,9 @@ def test_retry_error_document_success(mocker, rag_pipeline_service) -> None: workflow = mocker.Mock() - # Mock queries - mock_query = mocker.Mock() - # Log lookup, then Pipeline lookup - mock_query.where.return_value.first.side_effect = [log, pipeline] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query) + # Mock queries: Log lookup via scalar, Pipeline lookup via get + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=log) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) @@ -913,7 +869,7 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None: # Mock db aggressively mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db.engine = mocker.Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mocker.Mock() + mock_db.session.scalar.return_value = mocker.Mock() pipeline = mocker.Mock(spec=Pipeline) pipeline.id = "p-1" @@ -976,7 +932,7 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None: workflow = mocker.Mock(spec=Workflow) mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") - mock_db.session.query.return_value.where.return_value.first.return_value = workflow + mock_db.session.scalar.return_value = workflow # 2. Run test result = rag_pipeline_service.get_draft_workflow(pipeline) @@ -998,7 +954,7 @@ def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None: workflow = mocker.Mock(spec=Workflow) mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") - mock_db.session.query.return_value.where.return_value.first.return_value = workflow + mock_db.session.scalar.return_value = workflow # 2. Run test result = rag_pipeline_service.get_published_workflow(pipeline) @@ -1319,11 +1275,8 @@ def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, rag_pipeline_service) -> None: - query = mocker.Mock() - query.where.return_value = query - query.order_by.return_value.all.return_value = [] mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") - mock_db.session.query.return_value = query + mock_db.session.scalars.return_value.all.return_value = [] result = rag_pipeline_service.get_recommended_plugins("all") @@ -1336,11 +1289,8 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None: plugin_a = SimpleNamespace(plugin_id="plugin-a") plugin_b = SimpleNamespace(plugin_id="plugin-b") - query = mocker.Mock() - query.where.return_value = query - query.order_by.return_value.all.return_value = [plugin_a, plugin_b] mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") - mock_db.session.query.return_value = query + mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b] mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch( "services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", @@ -1568,9 +1518,7 @@ def test_get_second_step_parameters_filters_first_step_variables(mocker, rag_pip def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pipeline_service) -> None: - query = mocker.Mock() - query.where.return_value.first.return_value = None - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) with pytest.raises(ValueError, match="Document pipeline execution log not found"): rag_pipeline_service.retry_error_document( @@ -1581,9 +1529,7 @@ def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pi def test_get_datasource_plugins_raises_when_workflow_not_found(mocker, rag_pipeline_service) -> None: dataset = SimpleNamespace(pipeline_id="p1") pipeline = SimpleNamespace(id="p1", tenant_id="t1") - query = mocker.Mock() - query.where.return_value.first.side_effect = [dataset, pipeline] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None) with pytest.raises(ValueError, match="Pipeline or workflow not found"): @@ -1656,8 +1602,7 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker document = SimpleNamespace(indexing_status="waiting", error=None) query = mocker.Mock() - query.where.return_value.first.return_value = document - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document) add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add") commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") @@ -1712,9 +1657,7 @@ def test_run_datasource_node_preview_raises_for_unsupported_provider(mocker, rag def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker, rag_pipeline_service) -> None: - query = mocker.Mock() - query.where.return_value.first.return_value = None - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None) with pytest.raises(ValueError, match="Pipeline not found"): rag_pipeline_service.publish_customized_pipeline_template("p1", {}) @@ -1722,9 +1665,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None: pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None) - query = mocker.Mock() - query.where.return_value.first.return_value = pipeline - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline) with pytest.raises(ValueError, match="Pipeline workflow not found"): rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"}) @@ -1732,8 +1673,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None: query = mocker.Mock() - query.where.return_value.first.return_value = None - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) with pytest.raises(ValueError, match="Dataset not found"): rag_pipeline_service.get_pipeline("t1", "d1") @@ -1742,8 +1682,7 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None: dataset = SimpleNamespace(pipeline_id="p1") query = mocker.Mock() - query.where.return_value.first.side_effect = [dataset, None] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None]) with pytest.raises(ValueError, match="Pipeline not found"): rag_pipeline_service.get_pipeline("t1", "d1") @@ -1783,8 +1722,7 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None: def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None: template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) query = mocker.Mock() - query.where.return_value.first.return_value = template - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template) commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) @@ -2011,8 +1949,7 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None: pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1") query = mocker.Mock() - query.where.return_value.first.side_effect = [pipeline, None] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None]) with pytest.raises(ValueError, match="Workflow not found"): rag_pipeline_service.publish_customized_pipeline_template("p1", {}) @@ -2021,11 +1958,9 @@ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocke def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None: pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1") workflow = SimpleNamespace(id="wf-1") - query = mocker.Mock() - query.where.return_value.first.side_effect = [pipeline, workflow] mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db.engine = mocker.Mock() - mock_db.session.query.return_value = query + mock_db.session.get.side_effect = [pipeline, workflow] session_ctx = mocker.MagicMock() session_ctx.__enter__.return_value = SimpleNamespace() session_ctx.__exit__.return_value = False @@ -2038,11 +1973,8 @@ def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None: plugin = SimpleNamespace(plugin_id="plugin-a") - query = mocker.Mock() - query.where.return_value = query - query.order_by.return_value.all.return_value = [plugin] mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") - mock_db.session.query.return_value = query + mock_db.session.scalars.return_value.all.return_value = [plugin] mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[]) mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[]) @@ -2056,8 +1988,8 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None: exec_log = SimpleNamespace(pipeline_id="p1") query = mocker.Mock() - query.where.return_value.first.side_effect = [exec_log, None] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None) with pytest.raises(ValueError, match="Pipeline not found"): rag_pipeline_service.retry_error_document( @@ -2069,8 +2001,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_ exec_log = SimpleNamespace(pipeline_id="p1") pipeline = SimpleNamespace(id="p1") query = mocker.Mock() - query.where.return_value.first.side_effect = [exec_log, pipeline] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None) with pytest.raises(ValueError, match="Workflow not found"): @@ -2086,8 +2018,7 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[] ) query = mocker.Mock() - query.where.return_value.first.side_effect = [dataset, pipeline] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) assert rag_pipeline_service.get_datasource_plugins("t1", "d1", True) == [] @@ -2250,8 +2181,7 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published( rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}], ) query = mocker.Mock() - query.where.return_value.first.side_effect = [dataset, pipeline] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow) mocker.patch( "services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", return_value=[] @@ -2291,8 +2221,7 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag ], ) query = mocker.Mock() - query.where.return_value.first.side_effect = [dataset, pipeline] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) mocker.patch( "services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", @@ -2310,8 +2239,7 @@ def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) dataset = SimpleNamespace(pipeline_id="p1") pipeline = SimpleNamespace(id="p1") query = mocker.Mock() - query.where.return_value.first.side_effect = [dataset, pipeline] - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) + mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) result = rag_pipeline_service.get_pipeline("t1", "d1")