mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 16:39:26 +08:00
refactor: core/app pipeline, core/datasource, and core/indexing_runner (#34359)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -345,7 +345,7 @@ def test_generate_raises_when_workflow_not_found(generator, mocker):
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.get.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@@ -80,9 +80,7 @@ def test_get_workflow_returns_workflow(mocker, runner):
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe")
|
||||
workflow = MagicMock(id="wf")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = workflow
|
||||
mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query)))
|
||||
mocker.patch.object(module.db, "session", MagicMock(scalar=MagicMock(return_value=workflow)))
|
||||
|
||||
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
|
||||
|
||||
@@ -115,11 +113,8 @@ def test_init_rag_pipeline_graph_not_found(mocker, runner):
|
||||
def test_update_document_status_on_failure(mocker, runner):
|
||||
document = MagicMock()
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = document
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = document
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
event = GraphRunFailedEvent(error="boom")
|
||||
@@ -189,14 +184,10 @@ def test_run_single_iteration_path(mocker):
|
||||
app_generate_entity.single_iteration_run = MagicMock()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
query_end_user = MagicMock()
|
||||
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
|
||||
end_user = MagicMock(session_id="sess")
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = [query_end_user, query_pipeline]
|
||||
session.get.side_effect = [end_user, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
@@ -241,14 +232,10 @@ def test_run_normal_path_builds_graph(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
query_end_user = MagicMock()
|
||||
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
|
||||
end_user = MagicMock(session_id="sess")
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = [query_end_user, query_pipeline]
|
||||
session.get.side_effect = [end_user, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
workflow = MagicMock(
|
||||
|
||||
@@ -287,9 +287,7 @@ class TestDatasourceFileManager:
|
||||
mock_upload_file.key = "some_key"
|
||||
mock_upload_file.mime_type = "image/png"
|
||||
|
||||
mock_query = mock_db.session.query.return_value
|
||||
mock_where = mock_query.where.return_value
|
||||
mock_where.first.return_value = mock_upload_file
|
||||
mock_db.session.get.return_value = mock_upload_file
|
||||
|
||||
mock_storage.load_once.return_value = b"file content"
|
||||
|
||||
@@ -300,7 +298,7 @@ class TestDatasourceFileManager:
|
||||
assert result == (b"file content", "image/png")
|
||||
|
||||
# Case: Not found
|
||||
mock_where.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
assert DatasourceFileManager.get_file_binary("unknown") is None
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@@ -314,16 +312,14 @@ class TestDatasourceFileManager:
|
||||
mock_tool_file.file_key = "tool_key"
|
||||
mock_tool_file.mimetype = "image/png"
|
||||
|
||||
# Mock query sequence
|
||||
def mock_query(model):
|
||||
m = MagicMock()
|
||||
def mock_get(model, id):
|
||||
if model == MessageFile:
|
||||
m.where.return_value.first.return_value = mock_message_file
|
||||
return mock_message_file
|
||||
elif model == ToolFile:
|
||||
m.where.return_value.first.return_value = mock_tool_file
|
||||
return m
|
||||
return mock_tool_file
|
||||
return None
|
||||
|
||||
mock_db.session.query.side_effect = mock_query
|
||||
mock_db.session.get.side_effect = mock_get
|
||||
mock_storage.load_once.return_value = b"tool content"
|
||||
|
||||
# Execute
|
||||
@@ -344,15 +340,12 @@ class TestDatasourceFileManager:
|
||||
mock_tool_file.file_key = "tk"
|
||||
mock_tool_file.mimetype = "image/png"
|
||||
|
||||
def mock_query(model):
|
||||
m = MagicMock()
|
||||
def mock_get(model, id):
|
||||
if model == MessageFile:
|
||||
m.where.return_value.first.return_value = mock_message_file
|
||||
else:
|
||||
m.where.return_value.first.return_value = mock_tool_file
|
||||
return m
|
||||
return mock_message_file
|
||||
return mock_tool_file
|
||||
|
||||
mock_db.session.query.side_effect = mock_query
|
||||
mock_db.session.get.side_effect = mock_get
|
||||
mock_storage.load_once.return_value = b"bits"
|
||||
|
||||
result = DatasourceFileManager.get_file_binary_by_message_file_id("m")
|
||||
@@ -361,27 +354,20 @@ class TestDatasourceFileManager:
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db):
|
||||
# Setup common mock
|
||||
mock_query_obj = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query_obj
|
||||
mock_query_obj.where.return_value.first.return_value = None
|
||||
|
||||
# Case 1: Message file not found
|
||||
mock_db.session.get.return_value = None
|
||||
assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None
|
||||
|
||||
# Case 2: Message file found but tool file not found
|
||||
mock_message_file = MagicMock(spec=MessageFile)
|
||||
mock_message_file.url = None
|
||||
|
||||
def mock_query_v2(model):
|
||||
m = MagicMock()
|
||||
def mock_get_v2(model, id):
|
||||
if model == MessageFile:
|
||||
m.where.return_value.first.return_value = mock_message_file
|
||||
else:
|
||||
m.where.return_value.first.return_value = None
|
||||
return m
|
||||
return mock_message_file
|
||||
return None
|
||||
|
||||
mock_db.session.query.side_effect = mock_query_v2
|
||||
mock_db.session.get.side_effect = mock_get_v2
|
||||
assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@@ -392,7 +378,7 @@ class TestDatasourceFileManager:
|
||||
mock_upload_file.key = "upload_key"
|
||||
mock_upload_file.mime_type = "text/plain"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file
|
||||
mock_db.session.get.return_value = mock_upload_file
|
||||
|
||||
mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"])
|
||||
|
||||
@@ -404,7 +390,7 @@ class TestDatasourceFileManager:
|
||||
assert list(stream) == [b"chunk1", b"chunk2"]
|
||||
|
||||
# Case: Not found
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none")
|
||||
assert stream is None
|
||||
assert mimetype is None
|
||||
|
||||
@@ -795,33 +795,21 @@ class TestIndexingRunnerRun:
|
||||
doc = sample_dataset_documents[0]
|
||||
|
||||
# Mock database queries
|
||||
mock_dependencies["db"].session.get.return_value = doc
|
||||
|
||||
mock_dataset = Mock(spec=Dataset)
|
||||
mock_dataset.id = doc.dataset_id
|
||||
mock_dataset.tenant_id = doc.tenant_id
|
||||
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
|
||||
|
||||
mock_current_user = MagicMock()
|
||||
mock_current_user.set_tenant_id = MagicMock()
|
||||
|
||||
get_dispatch = {"Document": doc, "Dataset": mock_dataset, "Account": mock_current_user}
|
||||
mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__)
|
||||
|
||||
mock_process_rule = Mock(spec=DatasetProcessRule)
|
||||
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
|
||||
mock_dependencies["db"].session.scalar.return_value = mock_process_rule
|
||||
|
||||
# Mock current_user (Account) for _transform
|
||||
mock_current_user = MagicMock()
|
||||
mock_current_user.set_tenant_id = MagicMock()
|
||||
|
||||
# Setup db.session.query to return different results based on the model
|
||||
def mock_query_side_effect(model):
|
||||
mock_query_result = MagicMock()
|
||||
if model.__name__ == "Dataset":
|
||||
mock_query_result.filter_by.return_value.first.return_value = mock_dataset
|
||||
elif model.__name__ == "Account":
|
||||
mock_query_result.filter_by.return_value.first.return_value = mock_current_user
|
||||
return mock_query_result
|
||||
|
||||
mock_dependencies["db"].session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Mock processor
|
||||
mock_processor = MagicMock()
|
||||
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
|
||||
@@ -891,10 +879,11 @@ class TestIndexingRunnerRun:
|
||||
doc = sample_dataset_documents[0]
|
||||
|
||||
# Mock database
|
||||
mock_dependencies["db"].session.get.return_value = doc
|
||||
|
||||
mock_dataset = Mock(spec=Dataset)
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
|
||||
mock_dataset.tenant_id = doc.tenant_id
|
||||
|
||||
get_dispatch = {"Document": doc, "Dataset": mock_dataset}
|
||||
mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__)
|
||||
|
||||
mock_process_rule = Mock(spec=DatasetProcessRule)
|
||||
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
|
||||
@@ -917,11 +906,12 @@ class TestIndexingRunnerRun:
|
||||
runner = IndexingRunner()
|
||||
doc = sample_dataset_documents[0]
|
||||
|
||||
# Mock database to raise ObjectDeletedError
|
||||
mock_dependencies["db"].session.get.return_value = doc
|
||||
|
||||
# Mock database
|
||||
mock_dataset = Mock(spec=Dataset)
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
|
||||
mock_dataset.tenant_id = doc.tenant_id
|
||||
|
||||
get_dispatch = {"Document": doc, "Dataset": mock_dataset}
|
||||
mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__)
|
||||
|
||||
mock_process_rule = Mock(spec=DatasetProcessRule)
|
||||
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
|
||||
@@ -945,17 +935,21 @@ class TestIndexingRunnerRun:
|
||||
docs = sample_dataset_documents
|
||||
|
||||
# Mock database
|
||||
def get_side_effect(model_class, doc_id):
|
||||
for doc in docs:
|
||||
if doc.id == doc_id:
|
||||
return doc
|
||||
return None
|
||||
|
||||
mock_dependencies["db"].session.get.side_effect = get_side_effect
|
||||
|
||||
mock_dataset = Mock(spec=Dataset)
|
||||
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
|
||||
mock_current_user = MagicMock()
|
||||
mock_current_user.set_tenant_id = MagicMock()
|
||||
|
||||
doc_map = {doc.id: doc for doc in docs}
|
||||
model_dispatch = {"Dataset": mock_dataset, "Account": mock_current_user}
|
||||
|
||||
def get_side_effect(model_class, id):
|
||||
name = model_class.__name__
|
||||
if name == "Document":
|
||||
return doc_map.get(id)
|
||||
return model_dispatch.get(name)
|
||||
|
||||
mock_dependencies["db"].session.get.side_effect = get_side_effect
|
||||
|
||||
mock_process_rule = Mock(spec=DatasetProcessRule)
|
||||
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
|
||||
@@ -1035,9 +1029,8 @@ class TestIndexingRunnerRetryLogic:
|
||||
mock_document = Mock(spec=DatasetDocument)
|
||||
mock_document.id = document_id
|
||||
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 0
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.update.return_value = None
|
||||
mock_dependencies["db"].session.scalar.return_value = 0
|
||||
mock_dependencies["db"].session.get.return_value = mock_document
|
||||
|
||||
# Act
|
||||
IndexingRunner._update_document_index_status(
|
||||
@@ -1053,7 +1046,7 @@ class TestIndexingRunnerRetryLogic:
|
||||
"""Test document status update when document is paused."""
|
||||
# Arrange
|
||||
document_id = str(uuid.uuid4())
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 1
|
||||
mock_dependencies["db"].session.scalar.return_value = 1
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(DocumentIsPausedError):
|
||||
@@ -1063,8 +1056,8 @@ class TestIndexingRunnerRetryLogic:
|
||||
"""Test document status update when document is deleted."""
|
||||
# Arrange
|
||||
document_id = str(uuid.uuid4())
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 0
|
||||
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_dependencies["db"].session.scalar.return_value = 0
|
||||
mock_dependencies["db"].session.get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(DocumentIsDeletedPausedError):
|
||||
|
||||
Reference in New Issue
Block a user