fix: import path (#34124)

Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
QuantumGhost
2026-03-26 16:13:53 +08:00
committed by GitHub
parent 8ca54ddf94
commit e08c06cbc3
4 changed files with 42 additions and 33 deletions

View File

@@ -10,7 +10,6 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, create_autospec, patch from unittest.mock import MagicMock, Mock, create_autospec, patch
import pytest import pytest
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@@ -18,6 +17,7 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from enums.cloud_plan import CloudPlan from enums.cloud_plan import CloudPlan
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from models import Account, TenantAccountRole from models import Account, TenantAccountRole
from models.dataset import ( from models.dataset import (
ChildChunk, ChildChunk,

View File

@@ -190,7 +190,7 @@ class TestDatasetServiceValidation:
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
model_manager_cls.return_value.get_model_instance.assert_called_once_with( model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
@@ -201,7 +201,7 @@ class TestDatasetServiceValidation:
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError() model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
with pytest.raises(ValueError, match="No Embedding Model available"): with pytest.raises(ValueError, match="No Embedding Model available"):
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
@@ -210,14 +210,18 @@ class TestDatasetServiceValidation:
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing") model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
"token missing"
)
with pytest.raises(ValueError, match="token missing"): with pytest.raises(ValueError, match="The dataset is unavailable, due to: token missing"):
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
def test_check_embedding_model_setting_wraps_provider_token_error_description(self): def test_check_embedding_model_setting_wraps_provider_token_error_description(self):
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("provider setup") model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
"provider setup"
)
with pytest.raises(ValueError, match="provider setup"): with pytest.raises(ValueError, match="provider setup"):
DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model") DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model")
@@ -226,7 +230,7 @@ class TestDatasetServiceValidation:
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
model_manager_cls.return_value.get_model_instance.assert_called_once_with( model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
tenant_id="tenant-1", tenant_id="tenant-1",
provider="provider", provider="provider",
model_type=ModelType.RERANK, model_type=ModelType.RERANK,
@@ -235,7 +239,7 @@ class TestDatasetServiceValidation:
def test_check_reranking_model_setting_wraps_bad_request(self): def test_check_reranking_model_setting_wraps_bad_request(self):
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError() model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
with pytest.raises(ValueError, match="No Rerank Model available"): with pytest.raises(ValueError, match="No Rerank Model available"):
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
@@ -251,7 +255,7 @@ class TestDatasetServiceValidation:
) )
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.return_value = model_instance model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
@@ -268,7 +272,7 @@ class TestDatasetServiceValidation:
) )
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.return_value = model_instance model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
@@ -284,14 +288,14 @@ class TestDatasetServiceValidation:
) )
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.return_value = model_instance model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
with pytest.raises(ValueError, match="Model schema not found"): with pytest.raises(ValueError, match="Model schema not found"):
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
def test_check_is_multimodal_model_wraps_bad_request_error(self): def test_check_is_multimodal_model_wraps_bad_request_error(self):
with patch("services.dataset_service.ModelManager") as model_manager_cls: with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError() model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
with pytest.raises(ValueError, match="No Model available"): with pytest.raises(ValueError, match="No Model available"):
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
@@ -323,7 +327,7 @@ class TestDatasetServiceCreationAndUpdate:
patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
): ):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
model_manager_cls.return_value.get_default_model_instance.return_value = default_embedding_model model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id="tenant-1", tenant_id="tenant-1",
@@ -337,7 +341,7 @@ class TestDatasetServiceCreationAndUpdate:
assert dataset.embedding_model == "default-embedding" assert dataset.embedding_model == "default-embedding"
assert dataset.permission == DatasetPermissionEnum.ONLY_ME assert dataset.permission == DatasetPermissionEnum.ONLY_ME
assert dataset.provider == "vendor" assert dataset.provider == "vendor"
model_manager_cls.return_value.get_default_model_instance.assert_called_once_with( model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with(
tenant_id="tenant-1", tenant_id="tenant-1",
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
) )
@@ -365,7 +369,7 @@ class TestDatasetServiceCreationAndUpdate:
patch.object(DatasetService, "check_reranking_model_setting") as check_reranking, patch.object(DatasetService, "check_reranking_model_setting") as check_reranking,
): ):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id="tenant-1", tenant_id="tenant-1",
@@ -804,7 +808,7 @@ class TestDatasetServiceCreationAndUpdate:
return_value=SimpleNamespace(id="binding-1"), return_value=SimpleNamespace(id="binding-1"),
), ),
): ):
model_manager_cls.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
DatasetService._configure_embedding_model_for_high_quality( DatasetService._configure_embedding_model_for_high_quality(
{"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, {"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
@@ -836,7 +840,7 @@ class TestDatasetServiceCreationAndUpdate:
patch("services.dataset_service.current_user", current_user), patch("services.dataset_service.current_user", current_user),
patch("services.dataset_service.ModelManager") as model_manager_cls, patch("services.dataset_service.ModelManager") as model_manager_cls,
): ):
model_manager_cls.return_value.get_model_instance.side_effect = error model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = error
with pytest.raises(ValueError, match=message): with pytest.raises(ValueError, match=message):
DatasetService._configure_embedding_model_for_high_quality( DatasetService._configure_embedding_model_for_high_quality(
@@ -967,7 +971,7 @@ class TestDatasetServiceCreationAndUpdate:
return_value=SimpleNamespace(id="binding-2"), return_value=SimpleNamespace(id="binding-2"),
), ),
): ):
model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace( model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
provider="provider-two", provider="provider-two",
model_name="embedding-model-two", model_name="embedding-model-two",
) )
@@ -1002,7 +1006,9 @@ class TestDatasetServiceCreationAndUpdate:
patch("services.dataset_service.current_user", current_user), patch("services.dataset_service.current_user", current_user),
patch("services.dataset_service.ModelManager") as model_manager_cls, patch("services.dataset_service.ModelManager") as model_manager_cls,
): ):
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing") model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
"token missing"
)
DatasetService._apply_new_embedding_settings( DatasetService._apply_new_embedding_settings(
dataset, dataset,
@@ -1067,7 +1073,7 @@ class TestDatasetServiceRagPipelineSettings:
return_value=SimpleNamespace(id="binding-1"), return_value=SimpleNamespace(id="binding-1"),
), ),
): ):
model_manager_cls.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
@@ -1161,7 +1167,7 @@ class TestDatasetServiceRagPipelineSettings:
), ),
patch("services.dataset_service.deal_dataset_index_update_task") as update_task, patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
): ):
model_manager_cls.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
DatasetService.update_rag_pipeline_dataset_settings( DatasetService.update_rag_pipeline_dataset_settings(
session, session,
@@ -1204,7 +1210,7 @@ class TestDatasetServiceRagPipelineSettings:
), ),
patch("services.dataset_service.deal_dataset_index_update_task") as update_task, patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
): ):
model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace( model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
provider="provider-two", provider="provider-two",
model_name="embedding-model-two", model_name="embedding-model-two",
) )
@@ -1243,7 +1249,9 @@ class TestDatasetServiceRagPipelineSettings:
patch("services.dataset_service.ModelManager") as model_manager_cls, patch("services.dataset_service.ModelManager") as model_manager_cls,
patch("services.dataset_service.deal_dataset_index_update_task") as update_task, patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
): ):
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing") model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
"token missing"
)
DatasetService.update_rag_pipeline_dataset_settings( DatasetService.update_rag_pipeline_dataset_settings(
session, session,

View File

@@ -1828,7 +1828,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
) as get_binding, ) as get_binding,
patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document), patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document),
): ):
model_manager_cls.return_value.get_default_model_instance.return_value = SimpleNamespace( model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = SimpleNamespace(
model_name="default-embedding", model_name="default-embedding",
provider="default-provider", provider="default-provider",
) )
@@ -1880,7 +1880,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
): ):
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
model_manager_cls.return_value.get_default_model_instance.assert_not_called() model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_not_called()
get_binding.assert_called_once_with("explicit-provider", "explicit-model") get_binding.assert_called_once_with("explicit-provider", "explicit-model")
assert dataset.embedding_model == "explicit-model" assert dataset.embedding_model == "explicit-model"
assert dataset.embedding_model_provider == "explicit-provider" assert dataset.embedding_model_provider == "explicit-provider"

View File

@@ -9,6 +9,7 @@ from .dataset_service_test_helpers import (
DocumentSegment, DocumentSegment,
IndexStructureType, IndexStructureType,
MagicMock, MagicMock,
ModelType,
SegmentService, SegmentService,
SegmentUpdateArgs, SegmentUpdateArgs,
SimpleNamespace, SimpleNamespace,
@@ -459,7 +460,7 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.naive_utc_now", return_value="now"), patch("services.dataset_service.naive_utc_now", return_value="now"),
): ):
mock_redis.lock.return_value = _make_lock_context() mock_redis.lock.return_value = _make_lock_context()
model_manager_cls.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
mock_db.session.query.return_value.where.return_value.scalar.return_value = 1 mock_db.session.query.return_value.where.return_value.scalar.return_value = 1
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
@@ -571,7 +572,7 @@ class TestSegmentServiceMutations:
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model_instance model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance
processing_rule_query = MagicMock() processing_rule_query = MagicMock()
processing_rule_query.where.return_value.first.return_value = processing_rule processing_rule_query.where.return_value.first.return_value = processing_rule
@@ -618,7 +619,7 @@ class TestSegmentServiceMutations:
) as generate_summary, ) as generate_summary,
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock() summary_query = MagicMock()
summary_query.where.return_value.first.return_value = existing_summary summary_query.where.return_value.first.return_value = existing_summary
@@ -661,7 +662,7 @@ class TestSegmentServiceMutations:
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock() summary_query = MagicMock()
summary_query.where.return_value.first.return_value = existing_summary summary_query.where.return_value.first.return_value = existing_summary
@@ -900,7 +901,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
patch("services.dataset_service.naive_utc_now", return_value="now"), patch("services.dataset_service.naive_utc_now", return_value="now"),
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock() summary_query = MagicMock()
summary_query.where.return_value.first.return_value = None summary_query.where.return_value.first.return_value = None
refreshed_query = MagicMock() refreshed_query = MagicMock()
@@ -947,7 +948,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
model_manager_cls.return_value.get_default_model_instance.return_value = embedding_model_instance model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance
update_summary.side_effect = RuntimeError("summary failed") update_summary.side_effect = RuntimeError("summary failed")
processing_rule_query = MagicMock() processing_rule_query = MagicMock()
@@ -966,9 +967,9 @@ class TestSegmentServiceAdditionalRegenerationBranches:
) )
assert result is refreshed_segment assert result is refreshed_segment
model_manager_cls.return_value.get_default_model_instance.assert_called_once_with( model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with(
tenant_id="tenant-1", tenant_id="tenant-1",
model_type="text-embedding", model_type=ModelType.TEXT_EMBEDDING,
) )
vector_service.generate_child_chunks.assert_called_once_with( vector_service.generate_child_chunks.assert_called_once_with(
segment, segment,