mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 11:11:12 +08:00
fix: import path (#34124)
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -10,7 +10,6 @@ from types import SimpleNamespace
|
|||||||
from unittest.mock import MagicMock, Mock, create_autospec, patch
|
from unittest.mock import MagicMock, Mock, create_autospec, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||||
@@ -18,6 +17,7 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField
|
|||||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from enums.cloud_plan import CloudPlan
|
from enums.cloud_plan import CloudPlan
|
||||||
|
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||||
from models import Account, TenantAccountRole
|
from models import Account, TenantAccountRole
|
||||||
from models.dataset import (
|
from models.dataset import (
|
||||||
ChildChunk,
|
ChildChunk,
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ class TestDatasetServiceValidation:
|
|||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
DatasetService.check_dataset_model_setting(dataset)
|
DatasetService.check_dataset_model_setting(dataset)
|
||||||
|
|
||||||
model_manager_cls.return_value.get_model_instance.assert_called_once_with(
|
model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
|
||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
@@ -201,7 +201,7 @@ class TestDatasetServiceValidation:
|
|||||||
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
|
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
|
||||||
|
|
||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No Embedding Model available"):
|
with pytest.raises(ValueError, match="No Embedding Model available"):
|
||||||
DatasetService.check_dataset_model_setting(dataset)
|
DatasetService.check_dataset_model_setting(dataset)
|
||||||
@@ -210,14 +210,18 @@ class TestDatasetServiceValidation:
|
|||||||
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
|
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
|
||||||
|
|
||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
|
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
|
||||||
|
"token missing"
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="token missing"):
|
with pytest.raises(ValueError, match="The dataset is unavailable, due to: token missing"):
|
||||||
DatasetService.check_dataset_model_setting(dataset)
|
DatasetService.check_dataset_model_setting(dataset)
|
||||||
|
|
||||||
def test_check_embedding_model_setting_wraps_provider_token_error_description(self):
|
def test_check_embedding_model_setting_wraps_provider_token_error_description(self):
|
||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("provider setup")
|
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
|
||||||
|
"provider setup"
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="provider setup"):
|
with pytest.raises(ValueError, match="provider setup"):
|
||||||
DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model")
|
DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model")
|
||||||
@@ -226,7 +230,7 @@ class TestDatasetServiceValidation:
|
|||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
|
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
|
||||||
|
|
||||||
model_manager_cls.return_value.get_model_instance.assert_called_once_with(
|
model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
|
||||||
tenant_id="tenant-1",
|
tenant_id="tenant-1",
|
||||||
provider="provider",
|
provider="provider",
|
||||||
model_type=ModelType.RERANK,
|
model_type=ModelType.RERANK,
|
||||||
@@ -235,7 +239,7 @@ class TestDatasetServiceValidation:
|
|||||||
|
|
||||||
def test_check_reranking_model_setting_wraps_bad_request(self):
|
def test_check_reranking_model_setting_wraps_bad_request(self):
|
||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No Rerank Model available"):
|
with pytest.raises(ValueError, match="No Rerank Model available"):
|
||||||
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
|
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
|
||||||
@@ -251,7 +255,7 @@ class TestDatasetServiceValidation:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = model_instance
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
|
||||||
|
|
||||||
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
||||||
|
|
||||||
@@ -268,7 +272,7 @@ class TestDatasetServiceValidation:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = model_instance
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
|
||||||
|
|
||||||
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
||||||
|
|
||||||
@@ -284,14 +288,14 @@ class TestDatasetServiceValidation:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = model_instance
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Model schema not found"):
|
with pytest.raises(ValueError, match="Model schema not found"):
|
||||||
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
||||||
|
|
||||||
def test_check_is_multimodal_model_wraps_bad_request_error(self):
|
def test_check_is_multimodal_model_wraps_bad_request_error(self):
|
||||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||||
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No Model available"):
|
with pytest.raises(ValueError, match="No Model available"):
|
||||||
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
||||||
@@ -323,7 +327,7 @@ class TestDatasetServiceCreationAndUpdate:
|
|||||||
patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
|
patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
|
||||||
):
|
):
|
||||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||||
model_manager_cls.return_value.get_default_model_instance.return_value = default_embedding_model
|
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model
|
||||||
|
|
||||||
dataset = DatasetService.create_empty_dataset(
|
dataset = DatasetService.create_empty_dataset(
|
||||||
tenant_id="tenant-1",
|
tenant_id="tenant-1",
|
||||||
@@ -337,7 +341,7 @@ class TestDatasetServiceCreationAndUpdate:
|
|||||||
assert dataset.embedding_model == "default-embedding"
|
assert dataset.embedding_model == "default-embedding"
|
||||||
assert dataset.permission == DatasetPermissionEnum.ONLY_ME
|
assert dataset.permission == DatasetPermissionEnum.ONLY_ME
|
||||||
assert dataset.provider == "vendor"
|
assert dataset.provider == "vendor"
|
||||||
model_manager_cls.return_value.get_default_model_instance.assert_called_once_with(
|
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with(
|
||||||
tenant_id="tenant-1",
|
tenant_id="tenant-1",
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
)
|
)
|
||||||
@@ -365,7 +369,7 @@ class TestDatasetServiceCreationAndUpdate:
|
|||||||
patch.object(DatasetService, "check_reranking_model_setting") as check_reranking,
|
patch.object(DatasetService, "check_reranking_model_setting") as check_reranking,
|
||||||
):
|
):
|
||||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||||
|
|
||||||
dataset = DatasetService.create_empty_dataset(
|
dataset = DatasetService.create_empty_dataset(
|
||||||
tenant_id="tenant-1",
|
tenant_id="tenant-1",
|
||||||
@@ -804,7 +808,7 @@ class TestDatasetServiceCreationAndUpdate:
|
|||||||
return_value=SimpleNamespace(id="binding-1"),
|
return_value=SimpleNamespace(id="binding-1"),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||||
|
|
||||||
DatasetService._configure_embedding_model_for_high_quality(
|
DatasetService._configure_embedding_model_for_high_quality(
|
||||||
{"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
|
{"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
|
||||||
@@ -836,7 +840,7 @@ class TestDatasetServiceCreationAndUpdate:
|
|||||||
patch("services.dataset_service.current_user", current_user),
|
patch("services.dataset_service.current_user", current_user),
|
||||||
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
||||||
):
|
):
|
||||||
model_manager_cls.return_value.get_model_instance.side_effect = error
|
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = error
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=message):
|
with pytest.raises(ValueError, match=message):
|
||||||
DatasetService._configure_embedding_model_for_high_quality(
|
DatasetService._configure_embedding_model_for_high_quality(
|
||||||
@@ -967,7 +971,7 @@ class TestDatasetServiceCreationAndUpdate:
|
|||||||
return_value=SimpleNamespace(id="binding-2"),
|
return_value=SimpleNamespace(id="binding-2"),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
|
||||||
provider="provider-two",
|
provider="provider-two",
|
||||||
model_name="embedding-model-two",
|
model_name="embedding-model-two",
|
||||||
)
|
)
|
||||||
@@ -1002,7 +1006,9 @@ class TestDatasetServiceCreationAndUpdate:
|
|||||||
patch("services.dataset_service.current_user", current_user),
|
patch("services.dataset_service.current_user", current_user),
|
||||||
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
||||||
):
|
):
|
||||||
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
|
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
|
||||||
|
"token missing"
|
||||||
|
)
|
||||||
|
|
||||||
DatasetService._apply_new_embedding_settings(
|
DatasetService._apply_new_embedding_settings(
|
||||||
dataset,
|
dataset,
|
||||||
@@ -1067,7 +1073,7 @@ class TestDatasetServiceRagPipelineSettings:
|
|||||||
return_value=SimpleNamespace(id="binding-1"),
|
return_value=SimpleNamespace(id="binding-1"),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||||
|
|
||||||
DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
|
DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
|
||||||
|
|
||||||
@@ -1161,7 +1167,7 @@ class TestDatasetServiceRagPipelineSettings:
|
|||||||
),
|
),
|
||||||
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
|
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
|
||||||
):
|
):
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||||
|
|
||||||
DatasetService.update_rag_pipeline_dataset_settings(
|
DatasetService.update_rag_pipeline_dataset_settings(
|
||||||
session,
|
session,
|
||||||
@@ -1204,7 +1210,7 @@ class TestDatasetServiceRagPipelineSettings:
|
|||||||
),
|
),
|
||||||
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
|
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
|
||||||
):
|
):
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
|
||||||
provider="provider-two",
|
provider="provider-two",
|
||||||
model_name="embedding-model-two",
|
model_name="embedding-model-two",
|
||||||
)
|
)
|
||||||
@@ -1243,7 +1249,9 @@ class TestDatasetServiceRagPipelineSettings:
|
|||||||
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
||||||
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
|
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
|
||||||
):
|
):
|
||||||
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
|
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
|
||||||
|
"token missing"
|
||||||
|
)
|
||||||
|
|
||||||
DatasetService.update_rag_pipeline_dataset_settings(
|
DatasetService.update_rag_pipeline_dataset_settings(
|
||||||
session,
|
session,
|
||||||
|
|||||||
@@ -1828,7 +1828,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
|
|||||||
) as get_binding,
|
) as get_binding,
|
||||||
patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document),
|
patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document),
|
||||||
):
|
):
|
||||||
model_manager_cls.return_value.get_default_model_instance.return_value = SimpleNamespace(
|
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = SimpleNamespace(
|
||||||
model_name="default-embedding",
|
model_name="default-embedding",
|
||||||
provider="default-provider",
|
provider="default-provider",
|
||||||
)
|
)
|
||||||
@@ -1880,7 +1880,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
|
|||||||
):
|
):
|
||||||
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
||||||
|
|
||||||
model_manager_cls.return_value.get_default_model_instance.assert_not_called()
|
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_not_called()
|
||||||
get_binding.assert_called_once_with("explicit-provider", "explicit-model")
|
get_binding.assert_called_once_with("explicit-provider", "explicit-model")
|
||||||
assert dataset.embedding_model == "explicit-model"
|
assert dataset.embedding_model == "explicit-model"
|
||||||
assert dataset.embedding_model_provider == "explicit-provider"
|
assert dataset.embedding_model_provider == "explicit-provider"
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from .dataset_service_test_helpers import (
|
|||||||
DocumentSegment,
|
DocumentSegment,
|
||||||
IndexStructureType,
|
IndexStructureType,
|
||||||
MagicMock,
|
MagicMock,
|
||||||
|
ModelType,
|
||||||
SegmentService,
|
SegmentService,
|
||||||
SegmentUpdateArgs,
|
SegmentUpdateArgs,
|
||||||
SimpleNamespace,
|
SimpleNamespace,
|
||||||
@@ -459,7 +460,7 @@ class TestSegmentServiceMutations:
|
|||||||
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
||||||
):
|
):
|
||||||
mock_redis.lock.return_value = _make_lock_context()
|
mock_redis.lock.return_value = _make_lock_context()
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = 1
|
mock_db.session.query.return_value.where.return_value.scalar.return_value = 1
|
||||||
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
|
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
|
||||||
|
|
||||||
@@ -571,7 +572,7 @@ class TestSegmentServiceMutations:
|
|||||||
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
||||||
):
|
):
|
||||||
mock_redis.get.return_value = None
|
mock_redis.get.return_value = None
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model_instance
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance
|
||||||
|
|
||||||
processing_rule_query = MagicMock()
|
processing_rule_query = MagicMock()
|
||||||
processing_rule_query.where.return_value.first.return_value = processing_rule
|
processing_rule_query.where.return_value.first.return_value = processing_rule
|
||||||
@@ -618,7 +619,7 @@ class TestSegmentServiceMutations:
|
|||||||
) as generate_summary,
|
) as generate_summary,
|
||||||
):
|
):
|
||||||
mock_redis.get.return_value = None
|
mock_redis.get.return_value = None
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||||
|
|
||||||
summary_query = MagicMock()
|
summary_query = MagicMock()
|
||||||
summary_query.where.return_value.first.return_value = existing_summary
|
summary_query.where.return_value.first.return_value = existing_summary
|
||||||
@@ -661,7 +662,7 @@ class TestSegmentServiceMutations:
|
|||||||
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
||||||
):
|
):
|
||||||
mock_redis.get.return_value = None
|
mock_redis.get.return_value = None
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||||
|
|
||||||
summary_query = MagicMock()
|
summary_query = MagicMock()
|
||||||
summary_query.where.return_value.first.return_value = existing_summary
|
summary_query.where.return_value.first.return_value = existing_summary
|
||||||
@@ -900,7 +901,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
|||||||
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
||||||
):
|
):
|
||||||
mock_redis.get.return_value = None
|
mock_redis.get.return_value = None
|
||||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||||
summary_query = MagicMock()
|
summary_query = MagicMock()
|
||||||
summary_query.where.return_value.first.return_value = None
|
summary_query.where.return_value.first.return_value = None
|
||||||
refreshed_query = MagicMock()
|
refreshed_query = MagicMock()
|
||||||
@@ -947,7 +948,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
|||||||
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
||||||
):
|
):
|
||||||
mock_redis.get.return_value = None
|
mock_redis.get.return_value = None
|
||||||
model_manager_cls.return_value.get_default_model_instance.return_value = embedding_model_instance
|
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance
|
||||||
update_summary.side_effect = RuntimeError("summary failed")
|
update_summary.side_effect = RuntimeError("summary failed")
|
||||||
|
|
||||||
processing_rule_query = MagicMock()
|
processing_rule_query = MagicMock()
|
||||||
@@ -966,9 +967,9 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result is refreshed_segment
|
assert result is refreshed_segment
|
||||||
model_manager_cls.return_value.get_default_model_instance.assert_called_once_with(
|
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with(
|
||||||
tenant_id="tenant-1",
|
tenant_id="tenant-1",
|
||||||
model_type="text-embedding",
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
)
|
)
|
||||||
vector_service.generate_child_chunks.assert_called_once_with(
|
vector_service.generate_child_chunks.assert_called_once_with(
|
||||||
segment,
|
segment,
|
||||||
|
|||||||
Reference in New Issue
Block a user