refactor: select in rag_pipeline (#34495)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-04-03 05:42:01 +02:00
committed by GitHub
parent 33d4fd357c
commit 7eb632eb34
2 changed files with 95 additions and 168 deletions

View File

@@ -156,27 +156,27 @@ class RagPipelineService:
:param template_id: template id :param template_id: template id
:param template_info: template info :param template_info: template info
""" """
customized_template: PipelineCustomizedTemplate | None = ( customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
db.session.query(PipelineCustomizedTemplate) select(PipelineCustomizedTemplate)
.where( .where(
PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
) )
.first() .limit(1)
) )
if not customized_template: if not customized_template:
raise ValueError("Customized pipeline template not found.") raise ValueError("Customized pipeline template not found.")
# check template name is exist # check template name is exist
template_name = template_info.name template_name = template_info.name
if template_name: if template_name:
template = ( template = db.session.scalar(
db.session.query(PipelineCustomizedTemplate) select(PipelineCustomizedTemplate)
.where( .where(
PipelineCustomizedTemplate.name == template_name, PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.id != template_id, PipelineCustomizedTemplate.id != template_id,
) )
.first() .limit(1)
) )
if template: if template:
raise ValueError("Template name is already exists") raise ValueError("Template name is already exists")
@@ -192,13 +192,13 @@ class RagPipelineService:
""" """
Delete customized pipeline template. Delete customized pipeline template.
""" """
customized_template: PipelineCustomizedTemplate | None = ( customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
db.session.query(PipelineCustomizedTemplate) select(PipelineCustomizedTemplate)
.where( .where(
PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
) )
.first() .limit(1)
) )
if not customized_template: if not customized_template:
raise ValueError("Customized pipeline template not found.") raise ValueError("Customized pipeline template not found.")
@@ -210,14 +210,14 @@ class RagPipelineService:
Get draft workflow Get draft workflow
""" """
# fetch draft workflow by rag pipeline # fetch draft workflow by rag pipeline
workflow = ( workflow = db.session.scalar(
db.session.query(Workflow) select(Workflow)
.where( .where(
Workflow.tenant_id == pipeline.tenant_id, Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id, Workflow.app_id == pipeline.id,
Workflow.version == "draft", Workflow.version == "draft",
) )
.first() .limit(1)
) )
# return draft workflow # return draft workflow
@@ -232,28 +232,28 @@ class RagPipelineService:
return None return None
# fetch published workflow by workflow_id # fetch published workflow by workflow_id
workflow = ( workflow = db.session.scalar(
db.session.query(Workflow) select(Workflow)
.where( .where(
Workflow.tenant_id == pipeline.tenant_id, Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id, Workflow.app_id == pipeline.id,
Workflow.id == pipeline.workflow_id, Workflow.id == pipeline.workflow_id,
) )
.first() .limit(1)
) )
return workflow return workflow
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""Fetch a published workflow snapshot by ID for restore operations.""" """Fetch a published workflow snapshot by ID for restore operations."""
workflow = ( workflow = db.session.scalar(
db.session.query(Workflow) select(Workflow)
.where( .where(
Workflow.tenant_id == pipeline.tenant_id, Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id, Workflow.app_id == pipeline.id,
Workflow.id == workflow_id, Workflow.id == workflow_id,
) )
.first() .limit(1)
) )
if workflow and workflow.version == Workflow.VERSION_DRAFT: if workflow and workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("source workflow must be published") raise IsDraftWorkflowError("source workflow must be published")
@@ -974,7 +974,7 @@ class RagPipelineService:
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
if document_id: if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first() document = db.session.get(Document, document_id.value)
if document: if document:
document.indexing_status = IndexingStatus.ERROR document.indexing_status = IndexingStatus.ERROR
document.error = error document.error = error
@@ -1178,12 +1178,12 @@ class RagPipelineService:
""" """
Publish customized pipeline template Publish customized pipeline template
""" """
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first() pipeline = db.session.get(Pipeline, pipeline_id)
if not pipeline: if not pipeline:
raise ValueError("Pipeline not found") raise ValueError("Pipeline not found")
if not pipeline.workflow_id: if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found") raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() workflow = db.session.get(Workflow, pipeline.workflow_id)
if not workflow: if not workflow:
raise ValueError("Workflow not found") raise ValueError("Workflow not found")
with Session(db.engine) as session: with Session(db.engine) as session:
@@ -1194,21 +1194,21 @@ class RagPipelineService:
# check template name is exist # check template name is exist
template_name = args.get("name") template_name = args.get("name")
if template_name: if template_name:
template = ( template = db.session.scalar(
db.session.query(PipelineCustomizedTemplate) select(PipelineCustomizedTemplate)
.where( .where(
PipelineCustomizedTemplate.name == template_name, PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id, PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
) )
.first() .limit(1)
) )
if template: if template:
raise ValueError("Template name is already exists") raise ValueError("Template name is already exists")
max_position = ( max_position = db.session.scalar(
db.session.query(func.max(PipelineCustomizedTemplate.position)) select(func.max(PipelineCustomizedTemplate.position)).where(
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
.scalar() )
) )
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@@ -1239,13 +1239,14 @@ class RagPipelineService:
def is_workflow_exist(self, pipeline: Pipeline) -> bool: def is_workflow_exist(self, pipeline: Pipeline) -> bool:
return ( return (
db.session.query(Workflow) db.session.scalar(
.where( select(func.count(Workflow.id)).where(
Workflow.tenant_id == pipeline.tenant_id, Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id, Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT, Workflow.version == Workflow.VERSION_DRAFT,
) )
.count() )
or 0
) > 0 ) > 0
def get_node_last_run( def get_node_last_run(
@@ -1353,11 +1354,11 @@ class RagPipelineService:
def get_recommended_plugins(self, type: str) -> dict: def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins # Query active recommended plugins
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all": if type and type != "all":
query = query.where(PipelineRecommendedPlugin.type == type) stmt = stmt.where(PipelineRecommendedPlugin.type == type)
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all() pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()
if not pipeline_recommended_plugins: if not pipeline_recommended_plugins:
return { return {
@@ -1396,14 +1397,12 @@ class RagPipelineService:
""" """
Retry error document Retry error document
""" """
document_pipeline_execution_log = ( document_pipeline_execution_log = db.session.scalar(
db.session.query(DocumentPipelineExecutionLog) select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
) )
if not document_pipeline_execution_log: if not document_pipeline_execution_log:
raise ValueError("Document pipeline execution log not found") raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first() pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
if not pipeline: if not pipeline:
raise ValueError("Pipeline not found") raise ValueError("Pipeline not found")
# convert to app config # convert to app config
@@ -1432,23 +1431,23 @@ class RagPipelineService:
""" """
Get datasource plugins Get datasource plugins
""" """
dataset: Dataset | None = ( dataset: Dataset | None = db.session.scalar(
db.session.query(Dataset) select(Dataset)
.where( .where(
Dataset.id == dataset_id, Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id, Dataset.tenant_id == tenant_id,
) )
.first() .limit(1)
) )
if not dataset: if not dataset:
raise ValueError("Dataset not found") raise ValueError("Dataset not found")
pipeline: Pipeline | None = ( pipeline: Pipeline | None = db.session.scalar(
db.session.query(Pipeline) select(Pipeline)
.where( .where(
Pipeline.id == dataset.pipeline_id, Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id, Pipeline.tenant_id == tenant_id,
) )
.first() .limit(1)
) )
if not pipeline: if not pipeline:
raise ValueError("Pipeline not found") raise ValueError("Pipeline not found")
@@ -1530,23 +1529,23 @@ class RagPipelineService:
""" """
Get pipeline Get pipeline
""" """
dataset: Dataset | None = ( dataset: Dataset | None = db.session.scalar(
db.session.query(Dataset) select(Dataset)
.where( .where(
Dataset.id == dataset_id, Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id, Dataset.tenant_id == tenant_id,
) )
.first() .limit(1)
) )
if not dataset: if not dataset:
raise ValueError("Dataset not found") raise ValueError("Dataset not found")
pipeline: Pipeline | None = ( pipeline: Pipeline | None = db.session.scalar(
db.session.query(Pipeline) select(Pipeline)
.where( .where(
Pipeline.id == dataset.pipeline_id, Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id, Pipeline.tenant_id == tenant_id,
) )
.first() .limit(1)
) )
if not pipeline: if not pipeline:
raise ValueError("Pipeline not found") raise ValueError("Pipeline not found")

View File

@@ -117,9 +117,7 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv
def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None: def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None:
first_query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
first_query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=first_query)
with pytest.raises(ValueError, match="Dataset not found"): with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("tenant-1", "dataset-1") rag_pipeline_service.get_pipeline("tenant-1", "dataset-1")
@@ -131,12 +129,8 @@ def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service
def test_update_customized_pipeline_template_success(mocker) -> None: def test_update_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
# First query finds the template, second query (duplicate check) returns None # First scalar finds the template, second scalar (duplicate check) returns None
query_mock_1 = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None])
query_mock_1.where.return_value.first.return_value = template
query_mock_2 = mocker.Mock()
query_mock_2.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", side_effect=[query_mock_1, query_mock_2])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@@ -152,9 +146,7 @@ def test_update_customized_pipeline_template_success(mocker) -> None:
def test_update_customized_pipeline_template_not_found(mocker) -> None: def test_update_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i")) info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i"))
@@ -166,9 +158,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
duplicate = SimpleNamespace(name="dup") duplicate = SimpleNamespace(name="dup")
query_mock = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate])
query_mock.where.return_value.first.side_effect = [template, duplicate]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i")) info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i"))
@@ -181,9 +171,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
def test_delete_customized_pipeline_template_success(mocker) -> None: def test_delete_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(id="tpl-1") template = SimpleNamespace(id="tpl-1")
query_mock = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
query_mock.where.return_value.first.return_value = template
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete") delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@@ -196,9 +184,7 @@ def test_delete_customized_pipeline_template_success(mocker) -> None:
def test_delete_customized_pipeline_template_not_found(mocker) -> None: def test_delete_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
with pytest.raises(ValueError, match="Customized pipeline template not found"): with pytest.raises(ValueError, match="Customized pipeline template not found"):
@@ -397,18 +383,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None: def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1)
query_mock.where.return_value.count.return_value = 1
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
pipeline = SimpleNamespace(tenant_id="t1", id="p1") pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is True assert rag_pipeline_service.is_workflow_exist(pipeline) is True
def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None: def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0)
query_mock.where.return_value.count.return_value = 0
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
pipeline = SimpleNamespace(tenant_id="t1", id="p1") pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is False assert rag_pipeline_service.is_workflow_exist(pipeline) is False
@@ -738,8 +720,7 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non
def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None: def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None:
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate from models.dataset import Pipeline
from models.workflow import Workflow
# 1. Setup mocks # 1. Setup mocks
pipeline = mocker.Mock(spec=Pipeline) pipeline = mocker.Mock(spec=Pipeline)
@@ -754,36 +735,15 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
# Mock db itself to avoid app context errors # Mock db itself to avoid app context errors
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
# Improved mocking for session.query # Mock get() for Pipeline and Workflow PK lookups
def mock_query_side_effect(model): mock_db.session.get.side_effect = [pipeline, workflow]
m = mocker.Mock() # Mock scalar() for template name check (None) and max position (5)
if model == Pipeline: mock_db.session.scalar.side_effect = [None, 5]
m.where.return_value.first.return_value = pipeline
elif model == Workflow:
m.where.return_value.first.return_value = workflow
elif model == PipelineCustomizedTemplate:
m.where.return_value.first.return_value = None
elif model == Dataset:
m.where.return_value.first.return_value = mocker.Mock()
else:
# For func.max cases
m.where.return_value.scalar.return_value = 5
m.where.return_value.first.return_value = mocker.Mock()
return m
mock_db.session.query.side_effect = mock_query_side_effect
# Mock retrieve_dataset # Mock retrieve_dataset
dataset = mocker.Mock() dataset = mocker.Mock()
pipeline.retrieve_dataset.return_value = dataset pipeline.retrieve_dataset.return_value = dataset
# Mock max position
mocker.patch("services.rag_pipeline.rag_pipeline.func.max", return_value=1)
mocker.patch(
"services.rag_pipeline.rag_pipeline.db.session.query.return_value.where.return_value.scalar",
return_value=5,
)
# Mock RagPipelineDslService # Mock RagPipelineDslService
mock_dsl_service = mocker.Mock() mock_dsl_service = mocker.Mock()
mock_dsl_service.export_rag_pipeline_dsl.return_value = {"dsl": "content"} mock_dsl_service.export_rag_pipeline_dsl.return_value = {"dsl": "content"}
@@ -839,9 +799,7 @@ def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None:
workflow.rag_pipeline_variables = [] workflow.rag_pipeline_variables = []
# Mock queries # Mock queries
mock_query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mock_query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@@ -881,11 +839,9 @@ def test_retry_error_document_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock() workflow = mocker.Mock()
# Mock queries # Mock queries: Log lookup via scalar, Pipeline lookup via get
mock_query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=log)
# Log lookup, then Pipeline lookup mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mock_query.where.return_value.first.side_effect = [log, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@@ -913,7 +869,7 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None:
# Mock db aggressively # Mock db aggressively
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock() mock_db.engine = mocker.Mock()
mock_db.session.query.return_value.where.return_value.first.return_value = mocker.Mock() mock_db.session.scalar.return_value = mocker.Mock()
pipeline = mocker.Mock(spec=Pipeline) pipeline = mocker.Mock(spec=Pipeline)
pipeline.id = "p-1" pipeline.id = "p-1"
@@ -976,7 +932,7 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow) workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow mock_db.session.scalar.return_value = workflow
# 2. Run test # 2. Run test
result = rag_pipeline_service.get_draft_workflow(pipeline) result = rag_pipeline_service.get_draft_workflow(pipeline)
@@ -998,7 +954,7 @@ def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow) workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow mock_db.session.scalar.return_value = workflow
# 2. Run test # 2. Run test
result = rag_pipeline_service.get_published_workflow(pipeline) result = rag_pipeline_service.get_published_workflow(pipeline)
@@ -1319,11 +1275,8 @@ def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions
def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, rag_pipeline_service) -> None: def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = []
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query mock_db.session.scalars.return_value.all.return_value = []
result = rag_pipeline_service.get_recommended_plugins("all") result = rag_pipeline_service.get_recommended_plugins("all")
@@ -1336,11 +1289,8 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None: def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None:
plugin_a = SimpleNamespace(plugin_id="plugin-a") plugin_a = SimpleNamespace(plugin_id="plugin-a")
plugin_b = SimpleNamespace(plugin_id="plugin-b") plugin_b = SimpleNamespace(plugin_id="plugin-b")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin_a, plugin_b]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch( mocker.patch(
"services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", "services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools",
@@ -1568,9 +1518,7 @@ def test_get_second_step_parameters_filters_first_step_variables(mocker, rag_pip
def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pipeline_service) -> None: def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pipeline_service) -> None:
query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
with pytest.raises(ValueError, match="Document pipeline execution log not found"): with pytest.raises(ValueError, match="Document pipeline execution log not found"):
rag_pipeline_service.retry_error_document( rag_pipeline_service.retry_error_document(
@@ -1581,9 +1529,7 @@ def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pi
def test_get_datasource_plugins_raises_when_workflow_not_found(mocker, rag_pipeline_service) -> None: def test_get_datasource_plugins_raises_when_workflow_not_found(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1") dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1", tenant_id="t1") pipeline = SimpleNamespace(id="p1", tenant_id="t1")
query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Pipeline or workflow not found"): with pytest.raises(ValueError, match="Pipeline or workflow not found"):
@@ -1656,8 +1602,7 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
document = SimpleNamespace(indexing_status="waiting", error=None) document = SimpleNamespace(indexing_status="waiting", error=None)
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.return_value = document mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add") add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@@ -1712,9 +1657,7 @@ def test_run_datasource_node_preview_raises_for_unsupported_provider(mocker, rag
def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker, rag_pipeline_service) -> None: def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker, rag_pipeline_service) -> None:
query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
with pytest.raises(ValueError, match="Pipeline not found"): with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {}) rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@@ -1722,9 +1665,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker
def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None: def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None) pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None)
query = mocker.Mock() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
query.where.return_value.first.return_value = pipeline
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
with pytest.raises(ValueError, match="Pipeline workflow not found"): with pytest.raises(ValueError, match="Pipeline workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"}) rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"})
@@ -1732,8 +1673,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None: def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.return_value = None mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
with pytest.raises(ValueError, match="Dataset not found"): with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("t1", "d1") rag_pipeline_service.get_pipeline("t1", "d1")
@@ -1742,8 +1682,7 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None: def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1") dataset = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, None] mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
with pytest.raises(ValueError, match="Pipeline not found"): with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.get_pipeline("t1", "d1") rag_pipeline_service.get_pipeline("t1", "d1")
@@ -1783,8 +1722,7 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None: def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.return_value = template mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@@ -2011,8 +1949,7 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None: def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1") pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, None] mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
with pytest.raises(ValueError, match="Workflow not found"): with pytest.raises(ValueError, match="Workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {}) rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@@ -2021,11 +1958,9 @@ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocke
def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None: def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1") pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
workflow = SimpleNamespace(id="wf-1") workflow = SimpleNamespace(id="wf-1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, workflow]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock() mock_db.engine = mocker.Mock()
mock_db.session.query.return_value = query mock_db.session.get.side_effect = [pipeline, workflow]
session_ctx = mocker.MagicMock() session_ctx = mocker.MagicMock()
session_ctx.__enter__.return_value = SimpleNamespace() session_ctx.__enter__.return_value = SimpleNamespace()
session_ctx.__exit__.return_value = False session_ctx.__exit__.return_value = False
@@ -2038,11 +1973,8 @@ def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker
def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None: def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None:
plugin = SimpleNamespace(plugin_id="plugin-a") plugin = SimpleNamespace(plugin_id="plugin-a")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query mock_db.session.scalars.return_value.all.return_value = [plugin]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[]) mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[])
mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[]) mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[])
@@ -2056,8 +1988,8 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None: def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
exec_log = SimpleNamespace(pipeline_id="p1") exec_log = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, None] mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"): with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.retry_error_document( rag_pipeline_service.retry_error_document(
@@ -2069,8 +2001,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_
exec_log = SimpleNamespace(pipeline_id="p1") exec_log = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1") pipeline = SimpleNamespace(id="p1")
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, pipeline] mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query) mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Workflow not found"): with pytest.raises(ValueError, match="Workflow not found"):
@@ -2086,8 +2018,7 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[] graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
) )
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline] mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
assert rag_pipeline_service.get_datasource_plugins("t1", "d1", True) == [] assert rag_pipeline_service.get_datasource_plugins("t1", "d1", True) == []
@@ -2250,8 +2181,7 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}], rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
) )
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline] mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow) mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
mocker.patch( mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", return_value=[] "services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", return_value=[]
@@ -2291,8 +2221,7 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
], ],
) )
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline] mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow) mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
mocker.patch( mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", "services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials",
@@ -2310,8 +2239,7 @@ def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service)
dataset = SimpleNamespace(pipeline_id="p1") dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1") pipeline = SimpleNamespace(id="p1")
query = mocker.Mock() query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline] mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
result = rag_pipeline_service.get_pipeline("t1", "d1") result = rag_pipeline_service.get_pipeline("t1", "d1")