mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 02:19:20 +08:00
test: added test for api/services/rag_pipeline folder (#33222)
Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
This commit is contained in:
@@ -574,7 +574,7 @@ class RagPipelineService:
|
||||
outputs=workflow_node_execution.outputs,
|
||||
)
|
||||
session.commit()
|
||||
if workflow_node_execution_db_model is not None:
|
||||
if isinstance(workflow_node_execution_db_model, WorkflowNodeExecutionModel):
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution_db_model,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
|
||||
|
||||
def test_get_type() -> None:
|
||||
retrieval = BuiltInPipelineTemplateRetrieval()
|
||||
|
||||
assert retrieval.get_type() == PipelineTemplateType.BUILTIN
|
||||
|
||||
|
||||
def test_get_pipeline_templates(mocker) -> None:
|
||||
mocker.patch.object(
|
||||
BuiltInPipelineTemplateRetrieval,
|
||||
"_get_builtin_data",
|
||||
return_value={
|
||||
"pipeline_templates": {
|
||||
"en-US": {"pipeline_templates": [{"id": "tpl-1"}]},
|
||||
"tpl-1": {"id": "tpl-1", "name": "Template 1"},
|
||||
}
|
||||
},
|
||||
)
|
||||
retrieval = BuiltInPipelineTemplateRetrieval()
|
||||
|
||||
templates = retrieval.get_pipeline_templates("en-US")
|
||||
|
||||
assert templates == {"pipeline_templates": [{"id": "tpl-1"}]}
|
||||
|
||||
|
||||
def test_get_pipeline_template_detail(mocker) -> None:
|
||||
mocker.patch.object(
|
||||
BuiltInPipelineTemplateRetrieval,
|
||||
"_get_builtin_data",
|
||||
return_value={
|
||||
"pipeline_templates": {
|
||||
"tpl-1": {"id": "tpl-1", "name": "Template 1"},
|
||||
}
|
||||
},
|
||||
)
|
||||
retrieval = BuiltInPipelineTemplateRetrieval()
|
||||
|
||||
detail = retrieval.get_pipeline_template_detail("tpl-1")
|
||||
|
||||
assert detail == {"id": "tpl-1", "name": "Template 1"}
|
||||
|
||||
|
||||
def test_get_pipeline_templates_missing_language_returns_empty_dict(mocker) -> None:
|
||||
mocker.patch.object(
|
||||
BuiltInPipelineTemplateRetrieval,
|
||||
"_get_builtin_data",
|
||||
return_value={"pipeline_templates": {}},
|
||||
)
|
||||
retrieval = BuiltInPipelineTemplateRetrieval()
|
||||
|
||||
result = retrieval.get_pipeline_templates("fr-FR")
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_get_pipeline_template_detail_returns_none_for_unknown_id(mocker) -> None:
|
||||
mocker.patch.object(
|
||||
BuiltInPipelineTemplateRetrieval,
|
||||
"_get_builtin_data",
|
||||
return_value={"pipeline_templates": {"tpl-1": {"id": "tpl-1"}}},
|
||||
)
|
||||
retrieval = BuiltInPipelineTemplateRetrieval()
|
||||
|
||||
result = retrieval.get_pipeline_template_detail("nonexistent-id")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_builtin_data_reads_from_file_and_caches(mocker) -> None:
|
||||
import json
|
||||
|
||||
# Ensure no cached data
|
||||
BuiltInPipelineTemplateRetrieval.builtin_data = None
|
||||
|
||||
mock_app = mocker.Mock()
|
||||
mock_app.root_path = "/fake/root"
|
||||
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.built_in.built_in_retrieval.current_app",
|
||||
mock_app,
|
||||
)
|
||||
|
||||
test_data = {"pipeline_templates": {"en-US": {"templates": []}}}
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.built_in.built_in_retrieval.Path.read_text",
|
||||
return_value=json.dumps(test_data),
|
||||
)
|
||||
|
||||
result = BuiltInPipelineTemplateRetrieval._get_builtin_data()
|
||||
|
||||
assert result == test_data
|
||||
assert BuiltInPipelineTemplateRetrieval.builtin_data == test_data
|
||||
|
||||
# Reset class state
|
||||
BuiltInPipelineTemplateRetrieval.builtin_data = None
|
||||
|
||||
|
||||
def test_get_builtin_data_returns_cache_on_second_call(mocker) -> None:
|
||||
cached_data = {"pipeline_templates": {"en-US": {}}}
|
||||
BuiltInPipelineTemplateRetrieval.builtin_data = cached_data
|
||||
|
||||
result = BuiltInPipelineTemplateRetrieval._get_builtin_data()
|
||||
|
||||
assert result == cached_data
|
||||
|
||||
# Reset class state
|
||||
BuiltInPipelineTemplateRetrieval.builtin_data = None
|
||||
@@ -0,0 +1,89 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
|
||||
|
||||
def test_get_pipeline_templates(mocker) -> None:
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.customized.customized_retrieval.current_account_with_tenant",
|
||||
return_value=("account-id", "tenant-id"),
|
||||
)
|
||||
customized_template = SimpleNamespace(
|
||||
id="tpl-1",
|
||||
name="Custom Template",
|
||||
description="desc",
|
||||
icon={"background": "#fff"},
|
||||
position=2,
|
||||
chunk_structure="parent-child",
|
||||
)
|
||||
scalars_mock = mocker.Mock()
|
||||
scalars_mock.all.return_value = [customized_template]
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.scalars.return_value = scalars_mock
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.customized.customized_retrieval.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
retrieval = CustomizedPipelineTemplateRetrieval()
|
||||
|
||||
result = retrieval.get_pipeline_templates("en-US")
|
||||
|
||||
assert retrieval.get_type() == PipelineTemplateType.CUSTOMIZED
|
||||
assert result == {
|
||||
"pipeline_templates": [
|
||||
{
|
||||
"id": "tpl-1",
|
||||
"name": "Custom Template",
|
||||
"description": "desc",
|
||||
"icon": {"background": "#fff"},
|
||||
"position": 2,
|
||||
"chunk_structure": "parent-child",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_get_pipeline_template_detail_returns_detail(mocker) -> None:
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = SimpleNamespace(
|
||||
id="tpl-1",
|
||||
name="Custom Template",
|
||||
icon={"background": "#fff"},
|
||||
description="desc",
|
||||
chunk_structure="parent-child",
|
||||
yaml_content="workflow:\n graph:\n edges: []",
|
||||
created_user_name="creator",
|
||||
)
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.customized.customized_retrieval.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
retrieval = CustomizedPipelineTemplateRetrieval()
|
||||
|
||||
detail = retrieval.get_pipeline_template_detail("tpl-1")
|
||||
|
||||
assert detail == {
|
||||
"id": "tpl-1",
|
||||
"name": "Custom Template",
|
||||
"icon_info": {"background": "#fff"},
|
||||
"description": "desc",
|
||||
"chunk_structure": "parent-child",
|
||||
"export_data": "workflow:\n graph:\n edges: []",
|
||||
"graph": {"edges": []},
|
||||
"created_by": "creator",
|
||||
}
|
||||
|
||||
|
||||
def test_get_pipeline_template_detail_returns_none_when_not_found(mocker) -> None:
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = None
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.customized.customized_retrieval.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
retrieval = CustomizedPipelineTemplateRetrieval()
|
||||
|
||||
result = retrieval.get_pipeline_template_detail("missing")
|
||||
|
||||
assert result is None
|
||||
@@ -0,0 +1,87 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
|
||||
|
||||
def test_get_pipeline_templates(mocker) -> None:
|
||||
built_in_template = SimpleNamespace(
|
||||
id="tpl-1",
|
||||
name="Template 1",
|
||||
description="desc",
|
||||
icon={"background": "#fff"},
|
||||
copyright="copyright",
|
||||
privacy_policy="https://example.com/privacy",
|
||||
position=1,
|
||||
chunk_structure="general",
|
||||
)
|
||||
scalars_mock = mocker.Mock()
|
||||
scalars_mock.all.return_value = [built_in_template]
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.scalars.return_value = scalars_mock
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.database.database_retrieval.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
retrieval = DatabasePipelineTemplateRetrieval()
|
||||
|
||||
result = retrieval.get_pipeline_templates("en-US")
|
||||
|
||||
assert retrieval.get_type() == PipelineTemplateType.DATABASE
|
||||
assert result == {
|
||||
"pipeline_templates": [
|
||||
{
|
||||
"id": "tpl-1",
|
||||
"name": "Template 1",
|
||||
"description": "desc",
|
||||
"icon": {"background": "#fff"},
|
||||
"copyright": "copyright",
|
||||
"privacy_policy": "https://example.com/privacy",
|
||||
"position": 1,
|
||||
"chunk_structure": "general",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_get_pipeline_template_detail_returns_detail(mocker) -> None:
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = SimpleNamespace(
|
||||
id="tpl-1",
|
||||
name="Template 1",
|
||||
icon={"background": "#fff"},
|
||||
description="desc",
|
||||
chunk_structure="general",
|
||||
yaml_content="workflow:\n graph:\n nodes: []",
|
||||
)
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.database.database_retrieval.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
retrieval = DatabasePipelineTemplateRetrieval()
|
||||
|
||||
detail = retrieval.get_pipeline_template_detail("tpl-1")
|
||||
|
||||
assert detail == {
|
||||
"id": "tpl-1",
|
||||
"name": "Template 1",
|
||||
"icon_info": {"background": "#fff"},
|
||||
"description": "desc",
|
||||
"chunk_structure": "general",
|
||||
"export_data": "workflow:\n graph:\n nodes: []",
|
||||
"graph": {"nodes": []},
|
||||
}
|
||||
|
||||
|
||||
def test_get_pipeline_template_detail_returns_none_when_not_found(mocker) -> None:
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = None
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.database.database_retrieval.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
retrieval = DatabasePipelineTemplateRetrieval()
|
||||
|
||||
result = retrieval.get_pipeline_template_detail("missing")
|
||||
|
||||
assert result is None
|
||||
@@ -0,0 +1,19 @@
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"module_name",
|
||||
[
|
||||
"services.rag_pipeline.pipeline_template",
|
||||
"services.rag_pipeline.pipeline_template.built_in",
|
||||
"services.rag_pipeline.pipeline_template.customized",
|
||||
"services.rag_pipeline.pipeline_template.database",
|
||||
"services.rag_pipeline.pipeline_template.remote",
|
||||
],
|
||||
)
|
||||
def test_package_imports(module_name: str) -> None:
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
assert module is not None
|
||||
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||
|
||||
|
||||
class DummyRetrieval(PipelineTemplateRetrievalBase):
|
||||
def get_pipeline_templates(self, language: str) -> dict:
|
||||
return {"language": language}
|
||||
|
||||
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
|
||||
return {"id": template_id}
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
|
||||
class MissingTypeRetrieval(PipelineTemplateRetrievalBase):
|
||||
def get_pipeline_templates(self, language: str) -> dict:
|
||||
return {"language": language}
|
||||
|
||||
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
|
||||
return {"id": template_id}
|
||||
|
||||
|
||||
def test_pipeline_template_retrieval_base_concrete_implementation() -> None:
|
||||
retrieval = DummyRetrieval()
|
||||
|
||||
assert retrieval.get_pipeline_templates("en-US") == {"language": "en-US"}
|
||||
assert retrieval.get_pipeline_template_detail("tpl-1") == {"id": "tpl-1"}
|
||||
assert retrieval.get_type() == "dummy"
|
||||
|
||||
|
||||
def test_pipeline_template_retrieval_base_requires_abstract_methods() -> None:
|
||||
assert "get_type" in MissingTypeRetrieval.__abstractmethods__
|
||||
|
||||
|
||||
def test_pipeline_template_retrieval_base_default_methods_raise() -> None:
|
||||
with pytest.raises(NotImplementedError):
|
||||
PipelineTemplateRetrievalBase.get_pipeline_templates(DummyRetrieval(), "en-US")
|
||||
with pytest.raises(NotImplementedError):
|
||||
PipelineTemplateRetrievalBase.get_pipeline_template_detail(DummyRetrieval(), "tpl-1")
|
||||
with pytest.raises(NotImplementedError):
|
||||
PipelineTemplateRetrievalBase.get_type(DummyRetrieval())
|
||||
@@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
|
||||
from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
from services.rag_pipeline.pipeline_template.remote.remote_retrieval import RemotePipelineTemplateRetrieval
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "expected_cls"),
|
||||
[
|
||||
(PipelineTemplateType.REMOTE, RemotePipelineTemplateRetrieval),
|
||||
(PipelineTemplateType.CUSTOMIZED, CustomizedPipelineTemplateRetrieval),
|
||||
(PipelineTemplateType.DATABASE, DatabasePipelineTemplateRetrieval),
|
||||
(PipelineTemplateType.BUILTIN, BuiltInPipelineTemplateRetrieval),
|
||||
],
|
||||
)
|
||||
def test_get_pipeline_template_factory(mode: str, expected_cls: type) -> None:
|
||||
result = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
|
||||
|
||||
assert result is expected_cls
|
||||
|
||||
|
||||
def test_get_pipeline_template_factory_invalid_mode() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
PipelineTemplateRetrievalFactory.get_pipeline_template_factory("invalid")
|
||||
|
||||
|
||||
def test_get_built_in_pipeline_template_retrieval() -> None:
|
||||
result = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
|
||||
|
||||
assert result is BuiltInPipelineTemplateRetrieval
|
||||
@@ -0,0 +1,8 @@
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
|
||||
|
||||
def test_pipeline_template_type_values() -> None:
|
||||
assert PipelineTemplateType.REMOTE == "remote"
|
||||
assert PipelineTemplateType.DATABASE == "database"
|
||||
assert PipelineTemplateType.CUSTOMIZED == "customized"
|
||||
assert PipelineTemplateType.BUILTIN == "builtin"
|
||||
@@ -0,0 +1,98 @@
|
||||
import pytest
|
||||
|
||||
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
from services.rag_pipeline.pipeline_template.remote.remote_retrieval import RemotePipelineTemplateRetrieval
|
||||
|
||||
|
||||
def test_get_pipeline_templates_fallbacks_to_database_on_error(mocker) -> None:
|
||||
fetch_mock = mocker.patch.object(
|
||||
RemotePipelineTemplateRetrieval,
|
||||
"fetch_pipeline_templates_from_dify_official",
|
||||
side_effect=RuntimeError("boom"),
|
||||
)
|
||||
fallback_mock = mocker.patch.object(
|
||||
DatabasePipelineTemplateRetrieval,
|
||||
"fetch_pipeline_templates_from_db",
|
||||
return_value={"pipeline_templates": [{"id": "db-1"}]},
|
||||
)
|
||||
retrieval = RemotePipelineTemplateRetrieval()
|
||||
|
||||
result = retrieval.get_pipeline_templates("en-US")
|
||||
|
||||
assert retrieval.get_type() == PipelineTemplateType.REMOTE
|
||||
assert result == {"pipeline_templates": [{"id": "db-1"}]}
|
||||
fetch_mock.assert_called_once_with("en-US")
|
||||
fallback_mock.assert_called_once_with("en-US")
|
||||
|
||||
|
||||
def test_get_pipeline_template_detail_fallbacks_to_database_on_error(mocker) -> None:
|
||||
fetch_mock = mocker.patch.object(
|
||||
RemotePipelineTemplateRetrieval,
|
||||
"fetch_pipeline_template_detail_from_dify_official",
|
||||
side_effect=RuntimeError("boom"),
|
||||
)
|
||||
fallback_mock = mocker.patch.object(
|
||||
DatabasePipelineTemplateRetrieval,
|
||||
"fetch_pipeline_template_detail_from_db",
|
||||
return_value={"id": "db-1"},
|
||||
)
|
||||
retrieval = RemotePipelineTemplateRetrieval()
|
||||
|
||||
result = retrieval.get_pipeline_template_detail("tpl-1")
|
||||
|
||||
assert result == {"id": "db-1"}
|
||||
fetch_mock.assert_called_once_with("tpl-1")
|
||||
fallback_mock.assert_called_once_with("tpl-1")
|
||||
|
||||
|
||||
def test_fetch_pipeline_templates_from_dify_official(mocker) -> None:
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.remote.remote_retrieval"
|
||||
".dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN",
|
||||
"https://example.com",
|
||||
)
|
||||
|
||||
success_response = mocker.Mock(status_code=200)
|
||||
success_response.json.return_value = {"pipeline_templates": [{"id": "remote-1"}]}
|
||||
|
||||
failed_response = mocker.Mock(status_code=500)
|
||||
|
||||
http_get_mock = mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.remote.remote_retrieval.httpx.get",
|
||||
side_effect=[success_response, failed_response],
|
||||
)
|
||||
|
||||
success_result = RemotePipelineTemplateRetrieval.fetch_pipeline_templates_from_dify_official("en-US")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
RemotePipelineTemplateRetrieval.fetch_pipeline_templates_from_dify_official("en-US")
|
||||
|
||||
assert success_result == {"pipeline_templates": [{"id": "remote-1"}]}
|
||||
assert http_get_mock.call_count == 2
|
||||
|
||||
|
||||
def test_fetch_pipeline_template_detail_from_dify_official(mocker) -> None:
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.remote.remote_retrieval"
|
||||
".dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN",
|
||||
"https://example.com",
|
||||
)
|
||||
|
||||
success_response = mocker.Mock(status_code=200)
|
||||
success_response.json.return_value = {"id": "remote-1", "name": "Remote Template"}
|
||||
|
||||
failed_response = mocker.Mock(status_code=404)
|
||||
failed_response.text = "Not Found"
|
||||
|
||||
http_get_mock = mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.remote.remote_retrieval.httpx.get",
|
||||
side_effect=[success_response, failed_response],
|
||||
)
|
||||
|
||||
success_result = RemotePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_dify_official("remote-1")
|
||||
with pytest.raises(ValueError):
|
||||
RemotePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_dify_official("missing")
|
||||
|
||||
assert success_result == {"id": "remote-1", "name": "Remote Template"}
|
||||
assert http_get_mock.call_count == 2
|
||||
@@ -0,0 +1,155 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.dataset import Pipeline
|
||||
from models.model import Account, App, EndUser
|
||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||
|
||||
|
||||
def test_get_max_active_requests_uses_smallest_non_zero_limit(mocker) -> None:
|
||||
mocker.patch("services.rag_pipeline.pipeline_generate_service.dify_config.APP_DEFAULT_ACTIVE_REQUESTS", 5)
|
||||
mocker.patch("services.rag_pipeline.pipeline_generate_service.dify_config.APP_MAX_ACTIVE_REQUESTS", 3)
|
||||
|
||||
app_model = cast(App, SimpleNamespace(max_active_requests=10))
|
||||
|
||||
result = PipelineGenerateService._get_max_active_requests(app_model)
|
||||
|
||||
assert result == 3
|
||||
|
||||
|
||||
def test_get_max_active_requests_returns_zero_when_all_unlimited(mocker) -> None:
|
||||
mocker.patch("services.rag_pipeline.pipeline_generate_service.dify_config.APP_DEFAULT_ACTIVE_REQUESTS", 0)
|
||||
mocker.patch("services.rag_pipeline.pipeline_generate_service.dify_config.APP_MAX_ACTIVE_REQUESTS", 0)
|
||||
|
||||
app_model = cast(App, SimpleNamespace(max_active_requests=0))
|
||||
|
||||
result = PipelineGenerateService._get_max_active_requests(app_model)
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invoke_from", "workflow", "expected_error"),
|
||||
[
|
||||
(InvokeFrom.DEBUGGER, None, "Workflow not initialized"),
|
||||
(InvokeFrom.WEB_APP, None, "Workflow not published"),
|
||||
(InvokeFrom.DEBUGGER, SimpleNamespace(id="wf-1"), None),
|
||||
],
|
||||
)
|
||||
def test_get_workflow(mocker, invoke_from, workflow, expected_error) -> None:
|
||||
rag_pipeline_service_cls = mocker.patch("services.rag_pipeline.pipeline_generate_service.RagPipelineService")
|
||||
rag_pipeline_service = rag_pipeline_service_cls.return_value
|
||||
rag_pipeline_service.get_draft_workflow.return_value = workflow
|
||||
rag_pipeline_service.get_published_workflow.return_value = workflow
|
||||
|
||||
pipeline = cast(Pipeline, SimpleNamespace(id="pipeline-1"))
|
||||
|
||||
if expected_error:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
PipelineGenerateService._get_workflow(pipeline, invoke_from)
|
||||
else:
|
||||
result = PipelineGenerateService._get_workflow(pipeline, invoke_from)
|
||||
assert result == workflow
|
||||
|
||||
|
||||
def test_generate_updates_document_status_and_returns_event_stream(mocker) -> None:
|
||||
pipeline = cast(Pipeline, SimpleNamespace(id="pipeline-1"))
|
||||
user = cast(Account | EndUser, SimpleNamespace(id="user-1"))
|
||||
args = {"original_document_id": "doc-1", "query": "hello"}
|
||||
|
||||
mocker.patch.object(PipelineGenerateService, "_get_workflow", return_value=SimpleNamespace(id="wf-1"))
|
||||
update_status_mock = mocker.patch.object(PipelineGenerateService, "update_document_status")
|
||||
|
||||
generator_cls = mocker.patch("services.rag_pipeline.pipeline_generate_service.PipelineGenerator")
|
||||
generator_instance = generator_cls.return_value
|
||||
generator_instance.generate.return_value = "raw-events"
|
||||
generator_cls.convert_to_event_stream.return_value = "stream-events"
|
||||
|
||||
result = PipelineGenerateService.generate(
|
||||
pipeline=pipeline,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == "stream-events"
|
||||
update_status_mock.assert_called_once_with("doc-1")
|
||||
|
||||
|
||||
def test_update_document_status_updates_existing_document(mocker) -> None:
|
||||
document = SimpleNamespace(indexing_status="completed")
|
||||
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = document
|
||||
add_mock = session_mock.add
|
||||
commit_mock = session_mock.commit
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_generate_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
PipelineGenerateService.update_document_status("doc-1")
|
||||
|
||||
assert document.indexing_status == "waiting"
|
||||
add_mock.assert_called_once_with(document)
|
||||
commit_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_update_document_status_skips_when_document_missing(mocker) -> None:
|
||||
session_mock = mocker.Mock()
|
||||
session_mock.get.return_value = None
|
||||
add_mock = session_mock.add
|
||||
commit_mock = session_mock.commit
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_generate_service.db",
|
||||
new=SimpleNamespace(session=session_mock),
|
||||
)
|
||||
|
||||
PipelineGenerateService.update_document_status("missing")
|
||||
|
||||
add_mock.assert_not_called()
|
||||
commit_mock.assert_not_called()
|
||||
|
||||
|
||||
# --- generate_single_iteration ---
|
||||
|
||||
|
||||
def test_generate_single_iteration_delegates(mocker) -> None:
|
||||
mocker.patch.object(PipelineGenerateService, "_get_workflow", return_value=SimpleNamespace(id="wf-1"))
|
||||
|
||||
generator_cls = mocker.patch("services.rag_pipeline.pipeline_generate_service.PipelineGenerator")
|
||||
generator_instance = generator_cls.return_value
|
||||
generator_instance.single_iteration_generate.return_value = "raw-iter"
|
||||
generator_cls.convert_to_event_stream.return_value = "stream-iter"
|
||||
|
||||
pipeline = cast(Pipeline, SimpleNamespace(id="p1"))
|
||||
user = cast(Account, SimpleNamespace(id="u1"))
|
||||
|
||||
result = PipelineGenerateService.generate_single_iteration(pipeline, user, "node-1", {"key": "val"})
|
||||
|
||||
assert result == "stream-iter"
|
||||
generator_instance.single_iteration_generate.assert_called_once()
|
||||
|
||||
|
||||
# --- generate_single_loop ---
|
||||
|
||||
|
||||
def test_generate_single_loop_delegates(mocker) -> None:
|
||||
mocker.patch.object(PipelineGenerateService, "_get_workflow", return_value=SimpleNamespace(id="wf-1"))
|
||||
|
||||
generator_cls = mocker.patch("services.rag_pipeline.pipeline_generate_service.PipelineGenerator")
|
||||
generator_instance = generator_cls.return_value
|
||||
generator_instance.single_loop_generate.return_value = "raw-loop"
|
||||
generator_cls.convert_to_event_stream.return_value = "stream-loop"
|
||||
|
||||
pipeline = cast(Pipeline, SimpleNamespace(id="p1"))
|
||||
user = cast(Account, SimpleNamespace(id="u1"))
|
||||
|
||||
result = PipelineGenerateService.generate_single_loop(pipeline, user, "node-1", {"key": "val"})
|
||||
|
||||
assert result == "stream-loop"
|
||||
generator_instance.single_loop_generate.assert_called_once()
|
||||
@@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from services.rag_pipeline.entity.pipeline_service_api_entities import (
|
||||
DatasourceNodeRunApiEntity,
|
||||
PipelineRunApiEntity,
|
||||
)
|
||||
|
||||
|
||||
def test_datasource_node_run_api_entity_valid_payload() -> None:
|
||||
entity = DatasourceNodeRunApiEntity(
|
||||
pipeline_id="pipeline-1",
|
||||
node_id="node-1",
|
||||
inputs={"q": "hello"},
|
||||
datasource_type="local_file",
|
||||
credential_id="cred-1",
|
||||
is_published=True,
|
||||
)
|
||||
|
||||
assert entity.pipeline_id == "pipeline-1"
|
||||
assert entity.credential_id == "cred-1"
|
||||
|
||||
|
||||
def test_pipeline_run_api_entity_requires_start_node_id() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
PipelineRunApiEntity.model_validate(
|
||||
{
|
||||
"inputs": {"q": "hello"},
|
||||
"datasource_type": "local_file",
|
||||
"datasource_info_list": [{"id": "ds-1"}],
|
||||
"is_published": True,
|
||||
"response_mode": "streaming",
|
||||
}
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,24 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
|
||||
|
||||
|
||||
def test_list_rag_pipeline_datasources_marks_authorized(mocker) -> None:
|
||||
datasource_1 = SimpleNamespace(provider="notion", plugin_id="plugin-1", is_authorized=False)
|
||||
datasource_2 = SimpleNamespace(provider="jina", plugin_id="plugin-2", is_authorized=False)
|
||||
|
||||
manager_cls = mocker.patch("services.rag_pipeline.rag_pipeline_manage_service.PluginDatasourceManager")
|
||||
manager_cls.return_value.fetch_datasource_providers.return_value = [datasource_1, datasource_2]
|
||||
|
||||
provider_cls = mocker.patch("services.rag_pipeline.rag_pipeline_manage_service.DatasourceProviderService")
|
||||
provider_instance = provider_cls.return_value
|
||||
provider_instance.get_datasource_credentials.side_effect = [
|
||||
{"access_token": "token"},
|
||||
None,
|
||||
]
|
||||
|
||||
result = RagPipelineManageService.list_rag_pipeline_datasources("tenant-1")
|
||||
|
||||
assert result == [datasource_1, datasource_2]
|
||||
assert datasource_1.is_authorized is True
|
||||
assert datasource_2.is_authorized is False
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,159 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def proxy(mocker):
|
||||
"""Create a RagPipelineTaskProxy with mocked dependencies."""
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_task_proxy.TenantIsolatedTaskQueue")
|
||||
entity = Mock()
|
||||
entity.model_dump.return_value = {"doc": "data"}
|
||||
return RagPipelineTaskProxy(
|
||||
dataset_tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
rag_pipeline_invoke_entities=[entity],
|
||||
)
|
||||
|
||||
|
||||
# --- delay ---
|
||||
|
||||
|
||||
def test_delay_with_empty_entities_logs_warning_and_returns(mocker) -> None:
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_task_proxy.TenantIsolatedTaskQueue")
|
||||
proxy = RagPipelineTaskProxy(
|
||||
dataset_tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
rag_pipeline_invoke_entities=[],
|
||||
)
|
||||
dispatch_mock = mocker.patch.object(proxy, "_dispatch")
|
||||
|
||||
proxy.delay()
|
||||
|
||||
dispatch_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_delay_with_entities_calls_dispatch(mocker, proxy) -> None:
|
||||
dispatch_mock = mocker.patch.object(proxy, "_dispatch")
|
||||
|
||||
proxy.delay()
|
||||
|
||||
dispatch_mock.assert_called_once()
|
||||
|
||||
|
||||
# --- _dispatch ---
|
||||
|
||||
|
||||
def test_dispatch_billing_sandbox_uses_default_tenant_queue(mocker, proxy) -> None:
|
||||
upload_mock = mocker.patch.object(proxy, "_upload_invoke_entities", return_value="file-1")
|
||||
send_mock = mocker.patch.object(proxy, "_send_to_default_tenant_queue")
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan=CloudPlan.SANDBOX))
|
||||
)
|
||||
mocker.patch.object(type(proxy), "features", new_callable=lambda: property(lambda self: features))
|
||||
|
||||
proxy._dispatch()
|
||||
|
||||
upload_mock.assert_called_once()
|
||||
send_mock.assert_called_once_with("file-1")
|
||||
|
||||
|
||||
def test_dispatch_billing_non_sandbox_uses_priority_tenant_queue(mocker, proxy) -> None:
|
||||
upload_mock = mocker.patch.object(proxy, "_upload_invoke_entities", return_value="file-1")
|
||||
send_mock = mocker.patch.object(proxy, "_send_to_priority_tenant_queue")
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL))
|
||||
)
|
||||
mocker.patch.object(type(proxy), "features", new_callable=lambda: property(lambda self: features))
|
||||
|
||||
proxy._dispatch()
|
||||
|
||||
upload_mock.assert_called_once()
|
||||
send_mock.assert_called_once_with("file-1")
|
||||
|
||||
|
||||
def test_dispatch_no_billing_uses_priority_direct_queue(mocker, proxy) -> None:
    """With billing disabled, _dispatch bypasses tenant queues and sends to the priority direct queue."""
    upload_mock = mocker.patch.object(proxy, "_upload_invoke_entities", return_value="file-1")
    send_mock = mocker.patch.object(proxy, "_send_to_priority_direct_queue")

    features = SimpleNamespace(billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="free")))
    # Properties must be patched on the class, not the instance.
    mocker.patch.object(type(proxy), "features", new_callable=lambda: property(lambda self: features))

    proxy._dispatch()

    upload_mock.assert_called_once()
    send_mock.assert_called_once_with("file-1")
|
||||
|
||||
|
||||
def test_dispatch_raises_on_empty_upload_file_id(mocker, proxy) -> None:
    """_dispatch must fail fast with ValueError when the entity upload yields an empty file id."""
    mocker.patch.object(proxy, "_upload_invoke_entities", return_value="")

    features = SimpleNamespace(billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="free")))
    # Properties must be patched on the class, not the instance.
    mocker.patch.object(type(proxy), "features", new_callable=lambda: property(lambda self: features))

    with pytest.raises(ValueError, match="upload_file_id is empty"):
        proxy._dispatch()
|
||||
|
||||
|
||||
# --- _send_to_direct_queue ---
|
||||
|
||||
|
||||
def test_send_to_direct_queue_calls_task_func_delay(proxy) -> None:
    """_send_to_direct_queue enqueues immediately via the task function's .delay().

    The uploaded file id and the proxy's tenant id must be passed through as
    keyword arguments.
    """
    # No mocker needed here: a plain Mock stands in for the Celery task function.
    task_func = Mock()

    proxy._send_to_direct_queue("file-1", task_func)

    task_func.delay.assert_called_once_with(
        rag_pipeline_invoke_entities_file_id="file-1",
        tenant_id="tenant-1",
    )
|
||||
|
||||
|
||||
# --- _send_to_tenant_queue ---
|
||||
|
||||
|
||||
def test_send_to_tenant_queue_pushes_when_task_key_exists(proxy) -> None:
    """When a task key already exists, the file id is queued and no new Celery task is started."""
    # An existing key means a worker is already draining this tenant's queue.
    proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key"
    task_func = Mock()

    proxy._send_to_tenant_queue("file-1", task_func)

    proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with(["file-1"])
    task_func.delay.assert_not_called()
|
||||
|
||||
|
||||
def test_send_to_tenant_queue_sets_waiting_time_and_calls_delay(proxy) -> None:
    """Without an existing task key, the proxy records a waiting time and starts a new task."""
    # No key: nothing is processing this tenant yet, so a fresh task must be launched.
    proxy._tenant_isolated_task_queue.get_task_key.return_value = None
    task_func = Mock()

    proxy._send_to_tenant_queue("file-1", task_func)

    proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
    task_func.delay.assert_called_once_with(
        rag_pipeline_invoke_entities_file_id="file-1",
        tenant_id="tenant-1",
    )
|
||||
|
||||
|
||||
# --- _upload_invoke_entities ---
|
||||
|
||||
|
||||
def test_upload_invoke_entities_returns_file_id(mocker, proxy) -> None:
    """_upload_invoke_entities serializes the entities via FileService and returns the upload's id."""
    upload_file = SimpleNamespace(id="uploaded-file-1")
    file_service_cls = mocker.patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
    file_service_cls.return_value.upload_text.return_value = upload_file
    # db is patched so FileService can be constructed without a real engine.
    mocker.patch("services.rag_pipeline.rag_pipeline_task_proxy.db", mocker.Mock(engine="fake-engine"))

    result = proxy._upload_invoke_entities()

    assert result == "uploaded-file-1"
    file_service_cls.return_value.upload_text.assert_called_once()
|
||||
@@ -0,0 +1,516 @@
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import Dataset
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
|
||||
from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTransformService
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("doc_form", "datasource_type", "indexing_technique"),
    [
        ("text_model", "upload_file", "high_quality"),
        ("text_model", "upload_file", "economy"),
        ("text_model", "notion_import", "high_quality"),
        ("text_model", "notion_import", "economy"),
        ("text_model", "website_crawl", "high_quality"),
        ("text_model", "website_crawl", "economy"),
        # hierarchical_model templates do not vary by indexing technique.
        ("hierarchical_model", "upload_file", None),
        ("hierarchical_model", "notion_import", None),
        ("hierarchical_model", "website_crawl", None),
    ],
)
def test_get_transform_yaml_returns_workflow(doc_form: str, datasource_type: str, indexing_technique: str | None) -> None:
    """Every supported (doc_form, datasource, indexing) combination yields a dict with a workflow."""
    service = RagPipelineTransformService()

    result = service._get_transform_yaml(doc_form, datasource_type, indexing_technique)

    assert isinstance(result, dict)
    assert "workflow" in result
|
||||
|
||||
|
||||
def test_get_transform_yaml_raises_for_unsupported_doc_form() -> None:
    """An unrecognized doc form must be rejected with a ValueError."""
    svc = RagPipelineTransformService()

    with pytest.raises(ValueError, match="Unsupported doc form"):
        svc._get_transform_yaml("unknown", "upload_file", "high_quality")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("doc_form", ["text_model", "hierarchical_model"])
def test_get_transform_yaml_raises_for_unsupported_datasource_type(doc_form: str) -> None:
    """An unrecognized datasource type must be rejected for every supported doc form."""
    svc = RagPipelineTransformService()

    with pytest.raises(ValueError, match="Unsupported datasource type"):
        svc._get_transform_yaml(doc_form, "unsupported", "high_quality")
|
||||
|
||||
|
||||
def test_deal_file_extensions_filters_and_normalizes_extensions() -> None:
    """_deal_file_extensions lowercases extensions and drops unsupported ones (e.g. exe)."""
    svc = RagPipelineTransformService()
    input_node = {"data": {"fileExtensions": ["pdf", "TXT", "exe"]}}

    output_node = svc._deal_file_extensions(input_node)

    assert output_node["data"]["fileExtensions"] == ["pdf", "txt"]
|
||||
|
||||
|
||||
def test_deal_file_extensions_returns_original_when_empty() -> None:
    """With no extensions listed, _deal_file_extensions returns the node object unchanged."""
    svc = RagPipelineTransformService()
    input_node = {"data": {"fileExtensions": []}}

    output_node = svc._deal_file_extensions(input_node)

    # Identity check: the very same node object comes back, not a copy.
    assert output_node is input_node
|
||||
|
||||
|
||||
def test_deal_dependencies_installs_missing_marketplace_plugins(mocker) -> None:
    """_deal_dependencies installs only the marketplace plugins not already present.

    Installed plugins are reported by PluginInstaller; the missing plugin's fresh
    unique identifier is resolved through PluginMigration before installation.
    """
    service = RagPipelineTransformService()

    installer_cls = mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.PluginInstaller")
    installer_cls.return_value.list_plugins.return_value = [SimpleNamespace(plugin_id="installed-plugin")]

    migration_cls = mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.PluginMigration")
    migration_cls.return_value._fetch_plugin_unique_identifier.return_value = "missing-plugin:1.0.0"

    install_mock = mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.PluginService.install_from_marketplace_pkg"
    )

    pipeline_yaml = {
        "dependencies": [
            {"type": "marketplace", "value": {"plugin_unique_identifier": "installed-plugin:0.1.0"}},
            {"type": "marketplace", "value": {"plugin_unique_identifier": "missing-plugin:0.1.0"}},
        ]
    }

    service._deal_dependencies(pipeline_yaml, "tenant-1")

    # Only the missing plugin is installed, and under its resolved identifier.
    install_mock.assert_called_once_with("tenant-1", ["missing-plugin:1.0.0"])
|
||||
|
||||
|
||||
def test_transform_to_empty_pipeline_updates_dataset_and_commits(mocker) -> None:
    """_transform_to_empty_pipeline creates a pipeline, links it to the dataset, and commits.

    Pipeline is replaced with a fake whose id is fixed so the returned mapping and
    the dataset mutations can be asserted deterministically.
    """
    service = RagPipelineTransformService()
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.current_user",
        SimpleNamespace(id="user-1"),
    )

    class FakePipeline:
        # Mirrors the constructor keywords the service passes to Pipeline.
        def __init__(self, **kwargs):
            self.id = "pipeline-1"
            self.tenant_id = kwargs["tenant_id"]
            self.name = kwargs["name"]
            self.description = kwargs["description"]
            self.created_by = kwargs["created_by"]

    mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.Pipeline", FakePipeline)
    session_mock = mocker.Mock()
    add_mock = session_mock.add
    flush_mock = session_mock.flush
    commit_mock = session_mock.commit
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )

    dataset = SimpleNamespace(
        id="dataset-1",
        tenant_id="tenant-1",
        name="Dataset",
        description="desc",
        pipeline_id=None,
        runtime_mode="general",
        updated_by=None,
        updated_at=None,
    )

    result = service._transform_to_empty_pipeline(cast(Dataset, dataset))

    assert result == {"pipeline_id": "pipeline-1", "dataset_id": "dataset-1", "status": "success"}
    # The dataset is switched to pipeline runtime and stamped with the acting user.
    assert dataset.pipeline_id == "pipeline-1"
    assert dataset.runtime_mode == "rag_pipeline"
    assert dataset.updated_by == "user-1"
    add_mock.assert_called()
    flush_mock.assert_called_once()
    commit_mock.assert_called_once()
|
||||
|
||||
|
||||
# --- transform_dataset ---
|
||||
|
||||
|
||||
def test_transform_dataset_returns_early_when_pipeline_exists(mocker) -> None:
    """A dataset already running in rag_pipeline mode is returned as-is, with no transformation."""
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(
        id="d1",
        pipeline_id="p1",
        runtime_mode="rag_pipeline",
    )
    session_mock = mocker.Mock()
    session_mock.get.return_value = dataset
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )

    result = service.transform_dataset("d1")

    assert result == {"pipeline_id": "p1", "dataset_id": "d1", "status": "success"}
|
||||
|
||||
|
||||
def test_transform_dataset_raises_for_dataset_not_found(mocker) -> None:
    """An unknown dataset id makes transform_dataset raise ValueError."""
    svc = RagPipelineTransformService()

    # session.get returning None simulates a missing Dataset row.
    fake_session = mocker.Mock()
    fake_session.get.return_value = None
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=fake_session),
    )

    with pytest.raises(ValueError, match="Dataset not found"):
        svc.transform_dataset("d1")
|
||||
|
||||
|
||||
def test_transform_dataset_raises_for_external_dataset(mocker) -> None:
    """Datasets backed by an external provider cannot be transformed and must raise."""
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(
        id="d1",
        pipeline_id=None,
        runtime_mode=None,
        provider="external",
    )
    session_mock = mocker.Mock()
    session_mock.get.return_value = dataset
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )

    with pytest.raises(ValueError, match="External dataset is not supported"):
        service.transform_dataset("d1")
|
||||
|
||||
|
||||
def test_transform_dataset_calls_empty_pipeline_when_no_datasource(mocker) -> None:
    """A dataset without a datasource type falls back to the empty-pipeline transformation."""
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(
        id="d1",
        pipeline_id=None,
        runtime_mode=None,
        provider="vendor",
        data_source_type=None,  # triggers the empty-pipeline path
        indexing_technique=None,
    )
    session_mock = mocker.Mock()
    session_mock.get.return_value = dataset
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )

    empty_result = {"pipeline_id": "p-empty", "dataset_id": "d1", "status": "success"}
    mocker.patch.object(service, "_transform_to_empty_pipeline", return_value=empty_result)

    result = service.transform_dataset("d1")

    assert result == empty_result
|
||||
|
||||
|
||||
def test_transform_dataset_calls_empty_pipeline_when_no_doc_form(mocker) -> None:
    """A dataset with a datasource but no doc form also falls back to the empty-pipeline path."""
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(
        id="d1",
        pipeline_id=None,
        runtime_mode=None,
        provider="vendor",
        data_source_type="upload_file",
        indexing_technique="high_quality",
        doc_form=None,  # triggers the empty-pipeline path despite a valid datasource
    )
    session_mock = mocker.Mock()
    session_mock.get.return_value = dataset
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )

    empty_result = {"pipeline_id": "p-empty", "dataset_id": "d1", "status": "success"}
    mocker.patch.object(service, "_transform_to_empty_pipeline", return_value=empty_result)

    result = service.transform_dataset("d1")

    assert result == empty_result
|
||||
|
||||
|
||||
# --- _deal_knowledge_index ---
|
||||
|
||||
|
||||
def test_deal_knowledge_index_high_quality_sets_embedding() -> None:
    """For high-quality indexing, _deal_knowledge_index copies the dataset's embedding settings.

    The knowledge-index node starts with empty embedding fields; after processing they
    must hold the dataset's embedding model and provider. (No mocking is needed: the
    method works purely on the node dict and the dataset attributes.)
    """
    service = RagPipelineTransformService()
    dataset = cast(
        Dataset,
        SimpleNamespace(
            embedding_model="text-embedding-ada-002",
            embedding_model_provider="openai",
            retrieval_model=None,
            summary_index_setting=None,
        ),
    )
    node = {
        "data": {
            "type": "knowledge-index",
            "indexing_technique": "high_quality",
            "embedding_model": "",
            "embedding_model_provider": "",
            "retrieval_model": {
                "search_method": "semantic_search",
                "reranking_enable": False,
                "reranking_mode": None,
                "reranking_model": None,
                "weights": None,
                "top_k": 3,
                "score_threshold_enabled": False,
                "score_threshold": None,
            },
            "chunk_structure": "text_model",
            "keyword_number": None,
            "summary_index_setting": None,
        }
    }

    # Create KnowledgeConfiguration from node data
    knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
    retrieval_model = knowledge_configuration.retrieval_model

    result = service._deal_knowledge_index(
        knowledge_configuration,
        dataset,
        "high_quality",
        retrieval_model,
        node,
    )

    assert result["data"]["embedding_model"] == "text-embedding-ada-002"
    assert result["data"]["embedding_model_provider"] == "openai"
|
||||
|
||||
|
||||
# --- _deal_document_data ---
|
||||
|
||||
|
||||
def test_deal_document_data_notion(mocker) -> None:
    """Notion documents are migrated to the online_document source type with rewritten info.

    _deal_document_data should rewrite the document's data_source_info (embedding the
    page id) and persist both the document and a datasource provider log.
    """
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(id="d1", pipeline_id="p1")
    doc = SimpleNamespace(
        id="doc1",
        dataset_id="d1",
        data_source_type="notion_import",
        data_source_info_dict={
            "notion_workspace_id": "ws1",
            "notion_page_id": "page1",
            "notion_page_icon": "icon1",
            "type": "page",
            "last_edited_time": 12345,
        },
        name="Notion Doc",
        created_by="u1",
        created_at=datetime.now(UTC).replace(tzinfo=None),
        data_source_info=None,
    )

    # session.scalars(...).all() yields the single document to migrate.
    scalars_mock = mocker.Mock()
    scalars_mock.all.return_value = [doc]
    session_mock = mocker.Mock()
    session_mock.scalars.return_value = scalars_mock
    add_mock = session_mock.add
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )

    service._deal_document_data(cast(Dataset, dataset))

    assert doc.data_source_type == "online_document"
    assert "page1" in doc.data_source_info
    assert add_mock.call_count == 2  # document + log
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("provider", "node_id"), [("firecrawl", "1752565402678"), ("jinareader", "1752491761974")])
def test_deal_document_data_website(mocker, provider: str, node_id: str) -> None:
    """Website-crawl documents keep their source type but get a provider-specific log node id.

    Each crawl provider maps to a distinct datasource node id recorded on the
    DatasourceProviderLog added alongside the document.
    """
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(id="d1", pipeline_id="p1")
    doc = SimpleNamespace(
        id="doc1",
        dataset_id="d1",
        data_source_type="website_crawl",
        data_source_info_dict={
            "url": "https://example.com",
            "provider": provider,
        },
        name="Web Doc",
        created_by="u1",
        created_at=datetime.now(UTC).replace(tzinfo=None),
        data_source_info=None,
    )

    # session.scalars(...).all() yields the single document to migrate.
    scalars_mock = mocker.Mock()
    scalars_mock.all.return_value = [doc]
    session_mock = mocker.Mock()
    session_mock.scalars.return_value = scalars_mock
    add_mock = session_mock.add
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )

    service._deal_document_data(cast(Dataset, dataset))

    assert doc.data_source_type == "website_crawl"
    assert "example.com" in doc.data_source_info
    # Check if correct node id was used in log
    log = add_mock.call_args_list[1][0][0]
    assert log.datasource_node_id == node_id
|
||||
|
||||
|
||||
# --- transform_dataset complex flow ---
|
||||
|
||||
|
||||
def test_transform_dataset_full_flow(mocker) -> None:
    """Happy-path transform_dataset: a complete dataset is migrated to a new rag pipeline.

    Dependencies, document migration and pipeline creation are mocked; the test
    asserts the returned pipeline id and the dataset's runtime/chunk-structure updates.
    """
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(
        id="d1",
        tenant_id="t1",
        name="D",
        description="d",
        pipeline_id=None,
        runtime_mode=None,
        provider="vendor",
        data_source_type="upload_file",
        indexing_technique="high_quality",
        doc_form="text_model",
        retrieval_model={"search_method": "semantic_search", "top_k": 3},
        embedding_model="m1",
        embedding_model_provider="p1",
        summary_index_setting=None,
        chunk_structure=None,
    )

    session_mock = mocker.Mock()
    session_mock.get.return_value = dataset
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )

    mocker.patch.object(service, "_deal_dependencies")
    mocker.patch.object(service, "_deal_document_data")
    session_mock.commit = mocker.Mock()

    # Mock current_user to have the same tenant_id as dataset
    mock_current_user = SimpleNamespace(current_tenant_id="t1")
    mocker.patch("services.rag_pipeline.rag_pipeline_transform_service.current_user", mock_current_user)

    pipeline = SimpleNamespace(id="p-new")
    mocker.patch.object(service, "_create_pipeline", return_value=pipeline)

    result = service.transform_dataset("d1")

    assert result["pipeline_id"] == "p-new"
    assert dataset.runtime_mode == "rag_pipeline"
    assert dataset.chunk_structure == "text_model"
|
||||
|
||||
|
||||
def test_transform_dataset_raises_for_unsupported_doc_form_after_pipeline_create(mocker) -> None:
    """An unsupported doc form still raises even once yaml lookup and pipeline creation are mocked."""
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(
        id="d1",
        tenant_id="t1",
        name="D",
        description="d",
        pipeline_id=None,
        runtime_mode=None,
        provider="vendor",
        data_source_type="upload_file",
        indexing_technique="high_quality",
        doc_form="unsupported",  # the invalid value under test
        retrieval_model=None,
    )
    session_mock = mocker.Mock()
    session_mock.get.return_value = dataset
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )
    mocker.patch.object(service, "_get_transform_yaml", return_value={"workflow": {"graph": {"nodes": []}}})
    mocker.patch.object(service, "_deal_dependencies")
    mocker.patch.object(service, "_create_pipeline", return_value=SimpleNamespace(id="p-new"))

    with pytest.raises(ValueError, match="Unsupported doc form"):
        service.transform_dataset("d1")
|
||||
|
||||
|
||||
def test_transform_dataset_raises_when_transform_yaml_missing_workflow(mocker) -> None:
    """A transform yaml without a 'workflow' section aborts the migration with ValueError."""
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(
        id="d1",
        tenant_id="t1",
        name="D",
        description="d",
        pipeline_id=None,
        runtime_mode=None,
        provider="vendor",
        data_source_type="upload_file",
        indexing_technique="high_quality",
        doc_form="text_model",
        retrieval_model=None,
    )
    session_mock = mocker.Mock()
    session_mock.get.return_value = dataset
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )
    # Empty yaml: no "workflow" key present.
    mocker.patch.object(service, "_get_transform_yaml", return_value={})
    mocker.patch.object(service, "_deal_dependencies")

    with pytest.raises(ValueError, match="Missing workflow data for rag pipeline"):
        service.transform_dataset("d1")
|
||||
|
||||
|
||||
def test_create_pipeline_raises_when_workflow_data_missing() -> None:
    """_create_pipeline rejects yaml that lacks workflow data."""
    svc = RagPipelineTransformService()
    yaml_without_workflow = {"rag_pipeline": {"name": "N"}}

    with pytest.raises(ValueError, match="Missing workflow data for rag pipeline"):
        svc._create_pipeline(yaml_without_workflow)
|
||||
|
||||
|
||||
def test_deal_document_data_upload_file_with_existing_file(mocker) -> None:
    """Upload-file documents whose file still exists become local_file sources with a real_file_id."""
    service = RagPipelineTransformService()
    dataset = SimpleNamespace(id="d1", pipeline_id="p1")
    document = SimpleNamespace(
        id="doc-1",
        dataset_id="d1",
        data_source_type="upload_file",
        data_source_info_dict={"upload_file_id": "file-1"},
        name="Doc",
        created_by="u1",
        created_at=datetime.now(UTC).replace(tzinfo=None),
        data_source_info=None,
    )
    upload_file = SimpleNamespace(name="f.txt", size=10, extension="txt", mime_type="text/plain")

    # session.scalars(...).all() yields the document; session.get resolves its upload file.
    scalars_mock = mocker.Mock()
    scalars_mock.all.return_value = [document]
    session_mock = mocker.Mock()
    session_mock.scalars.return_value = scalars_mock
    session_mock.get.return_value = upload_file
    add_mock = session_mock.add
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_transform_service.db",
        new=SimpleNamespace(session=session_mock),
    )

    service._deal_document_data(cast(Dataset, dataset))

    assert document.data_source_type == "local_file"
    assert "real_file_id" in document.data_source_info
    assert add_mock.call_count >= 2
|
||||
Reference in New Issue
Block a user