From 5bafb163cc737f2c04d1693c227097924bbc5c74 Mon Sep 17 00:00:00 2001 From: Poojan Date: Thu, 2 Apr 2026 14:05:46 +0530 Subject: [PATCH] test: add unit tests for services and tasks part-4 (#33223) Co-authored-by: akashseth-ifp Co-authored-by: rajatagarwal-oss Co-authored-by: Dev Sharma <50591491+cryptus-neoxys@users.noreply.github.com> Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com> --- api/services/variable_truncator.py | 7 +- ...est_model_provider_service_sanitization.py | 641 ++++++++++++++ .../services/test_recommended_app_service.py | 198 ++++- .../services/test_schedule_service.py | 160 +++- .../services/test_variable_truncator.py | 259 +++++- .../services/test_webhook_service.py | 754 ++++++++++++++++ .../test_workflow_run_service_pause.py | 297 +++++++ .../test_workflow_converter_additional.py | 831 ++++++++++++++++++ .../test_workflow_event_snapshot_service.py | 577 +++++++++++- .../tasks/test_dataset_indexing_task.py | 490 ++++++++++- 10 files changed, 4180 insertions(+), 34 deletions(-) create mode 100644 api/tests/unit_tests/services/workflow/test_workflow_converter_additional.py diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 5427b7b3a7a..4d58a9cf12f 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -129,6 +129,7 @@ class VariableTruncator(BaseTruncator): used_size += self.calculate_json_size(key) if used_size > budget: truncated_mapping[key] = "..." 
+ is_truncated = True continue value_budget = (budget - used_size) // (length - len(truncated_mapping)) if isinstance(value, Segment): @@ -164,9 +165,9 @@ class VariableTruncator(BaseTruncator): result = self._truncate_segment(segment, self._max_size_bytes) if result.value_size > self._max_size_bytes: - if isinstance(result.value, str): - result = self._truncate_string(result.value, self._max_size_bytes) - return TruncationResult(StringSegment(value=result.value), True) + if isinstance(result.value, StringSegment): + fallback_result = self._truncate_string(result.value.value, self._max_size_bytes) + return TruncationResult(StringSegment(value=fallback_result.value), True) # Apply final fallback - convert to JSON string and truncate json_str = dumps_with_segments(result.value, ensure_ascii=False) diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 1bd979b9ec2..acf5dff634c 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -85,3 +85,644 @@ def test_get_provider_list_strips_credentials(service_with_fake_configurations: assert len(custom_models) == 1 # The sanitizer should drop credentials in list response assert custom_models[0].credentials is None + + +# === Merged from test_model_provider_service.py === + + +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType + +from core.entities.model_entities import ModelStatus +from models.provider import ProviderType +from services import model_provider_service as service_module +from services.errors.app_model_config import ProviderNotFoundError +from 
services.model_provider_service import ModelProviderService + + +def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]: + manager = MagicMock() + service = ModelProviderService() + service._get_provider_manager = MagicMock(return_value=manager) + return service, manager + + +def _build_provider_configuration( + *, + provider_name: str = "openai", + supported_model_types: list[ModelType] | None = None, + custom_models: list[Any] | None = None, + custom_config_available: bool = True, +) -> SimpleNamespace: + if supported_model_types is None: + supported_model_types = [ModelType.LLM] + return SimpleNamespace( + provider=SimpleNamespace( + provider=provider_name, + label=I18nObject(en_US=provider_name), + description=None, + icon_small=None, + icon_small_dark=None, + background=None, + help=None, + supported_model_types=supported_model_types, + configurate_methods=[], + provider_credential_schema=None, + model_credential_schema=None, + ), + preferred_provider_type=ProviderType.CUSTOM, + custom_configuration=SimpleNamespace( + provider=SimpleNamespace( + current_credential_id="cred-1", + current_credential_name="Credential 1", + available_credentials=[], + ), + models=custom_models, + can_added_models=[], + ), + system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]), + is_custom_configuration_available=lambda: custom_config_available, + ) + + +def test__get_provider_configuration_should_return_configuration_when_provider_exists() -> None: + # Arrange + service, manager = _create_service_with_mocked_manager() + provider_configuration = SimpleNamespace(name="provider-config") + manager.get_configurations.return_value = {"openai": provider_configuration} + + # Act + result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai") + + # Assert + assert result is provider_configuration + + +def test__get_provider_configuration_should_raise_error_when_provider_is_missing() -> 
None: + # Arrange + service, manager = _create_service_with_mocked_manager() + manager.get_configurations.return_value = {} + + # Act / Assert + with pytest.raises(ProviderNotFoundError, match="does not exist"): + service._get_provider_configuration(tenant_id="tenant-1", provider="missing") + + +def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status() -> None: + # Arrange + service, manager = _create_service_with_mocked_manager() + allowed = _build_provider_configuration( + provider_name="openai", + supported_model_types=[ModelType.LLM], + custom_config_available=False, + ) + filtered = _build_provider_configuration( + provider_name="embedding", + supported_model_types=[ModelType.TEXT_EMBEDDING], + custom_config_available=True, + ) + manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered} + + # Act + result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value) + + # Assert + assert len(result) == 1 + assert result[0].provider == "openai" + assert result[0].custom_configuration.status.value == "no-configure" + + +def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context() -> None: + # Arrange + service, manager = _create_service_with_mocked_manager() + + class _Model: + def __init__(self, model_name: str) -> None: + self.model_name = model_name + + def model_dump(self) -> dict[str, Any]: + return { + "model": self.model_name, + "label": {"en_US": self.model_name}, + "model_type": ModelType.LLM, + "features": [], + "fetch_from": FetchFrom.PREDEFINED_MODEL, + "model_properties": {}, + "deprecated": False, + "status": ModelStatus.ACTIVE, + "load_balancing_enabled": False, + "has_invalid_load_balancing_configs": False, + "provider": { + "provider": "openai", + "label": {"en_US": "OpenAI"}, + "icon_small": None, + "icon_small_dark": None, + "supported_model_types": [ModelType.LLM], + }, + } + + provider_configurations = SimpleNamespace( + 
get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")]) + ) + manager.get_configurations.return_value = provider_configurations + + # Act + result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai") + + # Assert + assert len(result) == 2 + assert result[0].model == "gpt-4o" + assert result[1].provider.provider == "openai" + provider_configurations.get_models.assert_called_once_with(provider="openai") + + +@pytest.mark.parametrize( + ("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"), + [ + ( + "get_provider_credential", + {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"}, + "get_provider_credential", + {"credential_id": "cred-1"}, + {"token": "abc"}, + ), + ( + "validate_provider_credentials", + {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}}, + "validate_provider_credentials", + ({"token": "abc"},), + None, + ), + ( + "create_provider_credential", + {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}, "credential_name": "A"}, + "create_provider_credential", + ({"token": "abc"}, "A"), + None, + ), + ( + "update_provider_credential", + { + "tenant_id": "tenant-1", + "provider": "openai", + "credentials": {"token": "abc"}, + "credential_id": "cred-1", + "credential_name": "B", + }, + "update_provider_credential", + {"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"}, + None, + ), + ( + "remove_provider_credential", + {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"}, + "delete_provider_credential", + {"credential_id": "cred-1"}, + None, + ), + ( + "switch_active_provider_credential", + {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"}, + "switch_active_provider_credential", + {"credential_id": "cred-1"}, + None, + ), + ], +) +def test_provider_credential_methods_should_delegate_to_provider_configuration( + 
method_name: str, + method_kwargs: dict[str, Any], + provider_method_name: str, + provider_call_kwargs: Any, + provider_return: Any, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + service = ModelProviderService() + provider_configuration = MagicMock() + getattr(provider_configuration, provider_method_name).return_value = provider_return + get_provider_config_mock = MagicMock(return_value=provider_configuration) + monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock) + + # Act + result = getattr(service, method_name)(**method_kwargs) + + # Assert + get_provider_config_mock.assert_called_once_with("tenant-1", "openai") + provider_method = getattr(provider_configuration, provider_method_name) + if isinstance(provider_call_kwargs, tuple): + provider_method.assert_called_once_with(*provider_call_kwargs) + elif isinstance(provider_call_kwargs, dict): + provider_method.assert_called_once_with(**provider_call_kwargs) + else: + provider_method.assert_called_once_with(provider_call_kwargs) + if method_name == "get_provider_credential": + assert result == {"token": "abc"} + + +@pytest.mark.parametrize( + ("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"), + [ + ( + "get_model_credential", + { + "tenant_id": "tenant-1", + "provider": "openai", + "model_type": ModelType.LLM.value, + "model": "gpt-4o", + "credential_id": "cred-1", + }, + "get_custom_model_credential", + {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"}, + {"api_key": "x"}, + ), + ( + "validate_model_credentials", + { + "tenant_id": "tenant-1", + "provider": "openai", + "model_type": ModelType.LLM.value, + "model": "gpt-4o", + "credentials": {"api_key": "x"}, + }, + "validate_custom_model_credentials", + {"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}}, + None, + ), + ( + "create_model_credential", + { + "tenant_id": "tenant-1", + "provider": "openai", + 
"model_type": ModelType.LLM.value, + "model": "gpt-4o", + "credentials": {"api_key": "x"}, + "credential_name": "cred-a", + }, + "create_custom_model_credential", + { + "model_type": ModelType.LLM, + "model": "gpt-4o", + "credentials": {"api_key": "x"}, + "credential_name": "cred-a", + }, + None, + ), + ( + "update_model_credential", + { + "tenant_id": "tenant-1", + "provider": "openai", + "model_type": ModelType.LLM.value, + "model": "gpt-4o", + "credentials": {"api_key": "x"}, + "credential_id": "cred-1", + "credential_name": "cred-b", + }, + "update_custom_model_credential", + { + "model_type": ModelType.LLM, + "model": "gpt-4o", + "credentials": {"api_key": "x"}, + "credential_id": "cred-1", + "credential_name": "cred-b", + }, + None, + ), + ( + "remove_model_credential", + { + "tenant_id": "tenant-1", + "provider": "openai", + "model_type": ModelType.LLM.value, + "model": "gpt-4o", + "credential_id": "cred-1", + }, + "delete_custom_model_credential", + {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"}, + None, + ), + ( + "switch_active_custom_model_credential", + { + "tenant_id": "tenant-1", + "provider": "openai", + "model_type": ModelType.LLM.value, + "model": "gpt-4o", + "credential_id": "cred-1", + }, + "switch_custom_model_credential", + {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"}, + None, + ), + ( + "add_model_credential_to_model_list", + { + "tenant_id": "tenant-1", + "provider": "openai", + "model_type": ModelType.LLM.value, + "model": "gpt-4o", + "credential_id": "cred-1", + }, + "add_model_credential_to_model", + {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"}, + None, + ), + ( + "remove_model", + { + "tenant_id": "tenant-1", + "provider": "openai", + "model_type": ModelType.LLM.value, + "model": "gpt-4o", + }, + "delete_custom_model", + {"model_type": ModelType.LLM, "model": "gpt-4o"}, + None, + ), + ], +) +def 
test_custom_model_methods_should_convert_model_type_and_delegate( + method_name: str, + method_kwargs: dict[str, Any], + provider_method_name: str, + expected_kwargs: dict[str, Any], + provider_return: Any, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + service = ModelProviderService() + provider_configuration = MagicMock() + getattr(provider_configuration, provider_method_name).return_value = provider_return + get_provider_config_mock = MagicMock(return_value=provider_configuration) + monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock) + + # Act + result = getattr(service, method_name)(**method_kwargs) + + # Assert + get_provider_config_mock.assert_called_once_with("tenant-1", "openai") + getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs) + if method_name == "get_model_credential": + assert result == {"api_key": "x"} + + +def test_get_models_by_model_type_should_group_active_non_deprecated_models() -> None: + # Arrange + service, manager = _create_service_with_mocked_manager() + openai_provider = SimpleNamespace( + provider="openai", + label=I18nObject(en_US="OpenAI"), + icon_small=None, + icon_small_dark=None, + ) + anthropic_provider = SimpleNamespace( + provider="anthropic", + label=I18nObject(en_US="Anthropic"), + icon_small=None, + icon_small_dark=None, + ) + models = [ + SimpleNamespace( + provider=openai_provider, + model="gpt-4o", + label=I18nObject(en_US="GPT-4o"), + model_type=ModelType.LLM, + features=[], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + status=ModelStatus.ACTIVE, + load_balancing_enabled=False, + deprecated=False, + ), + SimpleNamespace( + provider=openai_provider, + model="old-openai", + label=I18nObject(en_US="Old OpenAI"), + model_type=ModelType.LLM, + features=[], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + status=ModelStatus.ACTIVE, + load_balancing_enabled=False, + deprecated=True, + ), + 
SimpleNamespace( + provider=anthropic_provider, + model="old-anthropic", + label=I18nObject(en_US="Old Anthropic"), + model_type=ModelType.LLM, + features=[], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + status=ModelStatus.ACTIVE, + load_balancing_enabled=False, + deprecated=True, + ), + ] + provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models)) + manager.get_configurations.return_value = provider_configurations + + # Act + result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value) + + # Assert + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + assert len(result) == 1 + assert result[0].provider == "openai" + assert len(result[0].models) == 1 + assert result[0].models[0].model == "gpt-4o" + + +@pytest.mark.parametrize( + ("credentials", "schema", "expected_count"), + [ + (None, None, 0), + ({"api_key": "x"}, None, 0), + ( + {"api_key": "x"}, + SimpleNamespace( + parameter_rules=[ + ParameterRule( + name="temperature", + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + ) + ] + ), + 1, + ), + ], +) +def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema( + credentials: dict[str, Any] | None, + schema: Any, + expected_count: int, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + service = ModelProviderService() + provider_configuration = MagicMock() + provider_configuration.get_current_credentials.return_value = credentials + provider_configuration.get_model_schema.return_value = schema + monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration)) + + # Act + result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o") + + # Assert + assert len(result) == expected_count + provider_configuration.get_current_credentials.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4o") + 
if credentials: + provider_configuration.get_model_schema.assert_called_once_with( + model_type=ModelType.LLM, + model="gpt-4o", + credentials=credentials, + ) + else: + provider_configuration.get_model_schema.assert_not_called() + + +def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model() -> None: + # Arrange + service, manager = _create_service_with_mocked_manager() + manager.get_default_model.return_value = SimpleNamespace( + model="gpt-4o", + model_type=ModelType.LLM, + provider=SimpleNamespace( + provider="openai", + label=I18nObject(en_US="OpenAI"), + icon_small=None, + supported_model_types=[ModelType.LLM], + ), + ) + + # Act + result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value) + + # Assert + assert result is not None + assert result.model == "gpt-4o" + assert result.provider.provider == "openai" + manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM) + + +def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none() -> None: + # Arrange + service, manager = _create_service_with_mocked_manager() + manager.get_default_model.return_value = None + + # Act + result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value) + + # Assert + assert result is None + + +def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception() -> None: + # Arrange + service, manager = _create_service_with_mocked_manager() + manager.get_default_model.side_effect = RuntimeError("boom") + + # Act + result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value) + + # Assert + assert result is None + + +def test_update_default_model_of_model_type_should_delegate_to_provider_manager() -> None: + # Arrange + service, manager = _create_service_with_mocked_manager() + + # Act + service.update_default_model_of_model_type( + 
tenant_id="tenant-1", + model_type=ModelType.LLM.value, + provider="openai", + model="gpt-4o", + ) + + # Assert + manager.update_default_model_record.assert_called_once_with( + tenant_id="tenant-1", + model_type=ModelType.LLM, + provider="openai", + model="gpt-4o", + ) + + +def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + service = ModelProviderService() + factory_instance = MagicMock() + factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png") + factory_constructor = MagicMock(return_value=factory_instance) + monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor) + + # Act + result = service.get_model_provider_icon( + tenant_id="tenant-1", + provider="openai", + icon_type="icon_small", + lang="en_US", + ) + + # Assert + factory_constructor.assert_called_once_with(tenant_id="tenant-1") + factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US") + assert result == (b"icon-bytes", "image/png") + + +def test_switch_preferred_provider_should_convert_enum_and_delegate(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + service = ModelProviderService() + provider_configuration = MagicMock() + monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration)) + + # Act + service.switch_preferred_provider( + tenant_id="tenant-1", + provider="openai", + preferred_provider_type=ProviderType.SYSTEM.value, + ) + + # Assert + provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM) + + +@pytest.mark.parametrize( + ("method_name", "provider_method_name"), + [ + ("enable_model", "enable_model"), + ("disable_model", "disable_model"), + ], +) +def test_model_enablement_methods_should_convert_model_type_and_delegate( + method_name: str, + provider_method_name: str, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + 
service = ModelProviderService() + provider_configuration = MagicMock() + monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration)) + + # Act + getattr(service, method_name)( + tenant_id="tenant-1", + provider="openai", + model="gpt-4o", + model_type=ModelType.LLM.value, + ) + + # Assert + getattr(provider_configuration, provider_method_name).assert_called_once_with( + model="gpt-4o", + model_type=ModelType.LLM, + ) diff --git a/api/tests/unit_tests/services/test_recommended_app_service.py b/api/tests/unit_tests/services/test_recommended_app_service.py index 12f4c0b982f..12bc84db871 100644 --- a/api/tests/unit_tests/services/test_recommended_app_service.py +++ b/api/tests/unit_tests/services/test_recommended_app_service.py @@ -316,7 +316,7 @@ class TestRecommendedAppServiceGetDetail: mock_factory_class.get_recommend_app_factory.return_value = mock_factory # Act - result = RecommendedAppService.get_recommend_app_detail(app_id) + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) # Assert assert result == expected_detail @@ -346,7 +346,7 @@ class TestRecommendedAppServiceGetDetail: mock_factory_class.get_recommend_app_factory.return_value = mock_factory # Act - result = RecommendedAppService.get_recommend_app_detail(app_id) + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) # Assert assert result["name"] == f"App from {mode}" @@ -369,7 +369,7 @@ class TestRecommendedAppServiceGetDetail: mock_factory_class.get_recommend_app_factory.return_value = mock_factory # Act - result = RecommendedAppService.get_recommend_app_detail(app_id) + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) # Assert assert result is None @@ -392,7 +392,7 @@ class TestRecommendedAppServiceGetDetail: mock_factory_class.get_recommend_app_factory.return_value = mock_factory # Act - result = 
RecommendedAppService.get_recommend_app_detail(app_id) + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) # Assert assert result == {} @@ -432,9 +432,197 @@ class TestRecommendedAppServiceGetDetail: mock_factory_class.get_recommend_app_factory.return_value = mock_factory # Act - result = RecommendedAppService.get_recommend_app_detail(app_id) + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id)) # Assert assert result["model_config"] == complex_model_config assert len(result["workflows"]) == 2 assert len(result["tools"]) == 3 + + +# === Merged from test_recommended_app_service_additional.py === + + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest + +from services import recommended_app_service as service_module +from services.recommended_app_service import RecommendedAppService + + +def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any]: + return cast(dict[str, Any], result) + + +@pytest.fixture +def mocked_db_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + # Arrange + session = MagicMock() + monkeypatch.setattr(service_module, "db", SimpleNamespace(session=session)) + + # Assert + return session + + +def _mock_factory_for_apps( + monkeypatch: pytest.MonkeyPatch, + *, + mode: str, + result: dict[str, Any], + fallback_result: dict[str, Any] | None = None, +) -> tuple[MagicMock, MagicMock]: + retrieval_instance = MagicMock() + retrieval_instance.get_recommended_apps_and_categories.return_value = result + retrieval_factory = MagicMock(return_value=retrieval_instance) + monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False) + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_recommend_app_factory", + MagicMock(return_value=retrieval_factory), + ) + + builtin_instance = MagicMock() + if fallback_result is not None: + 
builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_buildin_recommend_app_retrieval", + MagicMock(return_value=builtin_instance), + ) + return retrieval_instance, builtin_instance + + +def test_get_recommended_apps_and_categories_should_not_query_trial_table_when_trial_feature_disabled( + monkeypatch: pytest.MonkeyPatch, + mocked_db_session: MagicMock, +) -> None: + # Arrange + expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]} + retrieval_instance, builtin_instance = _mock_factory_for_apps( + monkeypatch, + mode="remote", + result=expected, + ) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=False)), + ) + + # Act + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + # Assert + assert result == expected + retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") + builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called() + mocked_db_session.scalar.assert_not_called() + + +def test_get_recommended_apps_and_categories_should_fallback_and_enrich_can_trial_when_trial_feature_enabled( + monkeypatch: pytest.MonkeyPatch, + mocked_db_session: MagicMock, +) -> None: + # Arrange + remote_result = {"recommended_apps": [], "categories": []} + fallback_result = {"recommended_apps": [{"app_id": "app-1"}, {"app_id": "app-2"}], "categories": ["all"]} + _, builtin_instance = _mock_factory_for_apps( + monkeypatch, + mode="remote", + result=remote_result, + fallback_result=fallback_result, + ) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), + ) + mocked_db_session.scalar.side_effect = [SimpleNamespace(id="trial-app"), None] + + # Act + result = 
RecommendedAppService.get_recommended_apps_and_categories("ja-JP") + + # Assert + builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") + assert result["recommended_apps"][0]["can_trial"] is True + assert result["recommended_apps"][1]["can_trial"] is False + assert mocked_db_session.scalar.call_count == 2 + + +@pytest.mark.parametrize( + ("trial_query_result", "expected_can_trial"), + [ + (SimpleNamespace(id="trial"), True), + (None, False), + ], +) +def test_get_recommend_app_detail_should_set_can_trial_when_trial_feature_enabled( + monkeypatch: pytest.MonkeyPatch, + mocked_db_session: MagicMock, + trial_query_result: Any, + expected_can_trial: bool, +) -> None: + # Arrange + detail = {"id": "app-1", "name": "Test App"} + retrieval_instance = MagicMock() + retrieval_instance.get_recommend_app_detail.return_value = detail + retrieval_factory = MagicMock(return_value=retrieval_instance) + monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False) + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_recommend_app_factory", + MagicMock(return_value=retrieval_factory), + ) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), + ) + mocked_db_session.scalar.return_value = trial_query_result + + # Act + result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail("app-1")) + + # Assert + assert result["id"] == "app-1" + assert result["can_trial"] is expected_can_trial + mocked_db_session.scalar.assert_called_once() + + +def test_add_trial_app_record_should_increment_count_when_existing_record_found( + mocked_db_session: MagicMock, +) -> None: + # Arrange + existing_record = SimpleNamespace(count=3) + mocked_db_session.scalar.return_value = existing_record + + # Act + RecommendedAppService.add_trial_app_record("app-1", "account-1") + + # Assert + assert 
existing_record.count == 4 + mocked_db_session.scalar.assert_called_once() + mocked_db_session.commit.assert_called_once() + mocked_db_session.add.assert_not_called() + + +def test_add_trial_app_record_should_create_new_record_when_no_existing_record( + mocked_db_session: MagicMock, +) -> None: + # Arrange + mocked_db_session.scalar.return_value = None + + # Act + RecommendedAppService.add_trial_app_record("app-2", "account-2") + + # Assert + mocked_db_session.scalar.assert_called_once() + mocked_db_session.add.assert_called_once() + added = mocked_db_session.add.call_args.args[0] + assert added.app_id == "app-2" + assert added.account_id == "account-2" + assert added.count == 1 + mocked_db_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_schedule_service.py b/api/tests/unit_tests/services/test_schedule_service.py index e28965ea2c3..2a78876da61 100644 --- a/api/tests/unit_tests/services/test_schedule_service.py +++ b/api/tests/unit_tests/services/test_schedule_service.py @@ -1,12 +1,15 @@ import unittest from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any, cast from unittest.mock import MagicMock, Mock, patch import pytest from sqlalchemy.orm import Session +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig -from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError +from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError from events.event_handlers.sync_workflow_schedule_when_app_published import ( sync_schedule_from_workflow, ) @@ -14,6 +17,8 @@ from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan from models.workflow import Workflow +from services.errors.account import AccountNotFoundError +from 
services.trigger import schedule_service as service_module from services.trigger.schedule_service import ScheduleService @@ -775,5 +780,158 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase): mock_session.commit.assert_called_once() +@pytest.fixture +def session_mock() -> MagicMock: + return MagicMock(spec=Session) + + +def _workflow(**kwargs: Any) -> Workflow: + return cast(Workflow, SimpleNamespace(**kwargs)) + + +def test_update_schedule_should_update_only_node_id_without_recomputing_time( + session_mock: MagicMock, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + schedule = MagicMock(spec=WorkflowSchedulePlan) + schedule.cron_expression = "0 10 * * *" + schedule.timezone = "UTC" + session_mock.get.return_value = schedule + + next_run_mock = MagicMock(return_value=datetime(2026, 1, 1, 10, 0, tzinfo=UTC)) + monkeypatch.setattr(service_module, "calculate_next_run_at", next_run_mock) + + # Act + result = ScheduleService.update_schedule( + session=session_mock, + schedule_id="schedule-1", + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + # Assert + assert result is schedule + assert schedule.node_id == "node-new" + next_run_mock.assert_not_called() + session_mock.flush.assert_called_once() + + +def test_get_tenant_owner_should_raise_when_account_record_missing(session_mock: MagicMock) -> None: + # Arrange + join = SimpleNamespace(account_id="account-404") + session_mock.execute.return_value.scalar_one_or_none.return_value = join + session_mock.get.return_value = None + + # Act / Assert + with pytest.raises(AccountNotFoundError, match="Account not found: account-404"): + ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1") + + +def test_get_tenant_owner_should_raise_when_no_owner_or_admin_found(session_mock: MagicMock) -> None: + # Arrange + session_mock.execute.return_value.scalar_one_or_none.side_effect = [None, None] + + # Act / Assert + with pytest.raises(AccountNotFoundError, match="Account not found for tenant: 
tenant-1"): + ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1") + + +def test_update_next_run_at_should_raise_when_schedule_not_found(session_mock: MagicMock) -> None: + # Arrange + session_mock.get.return_value = None + + # Act / Assert + with pytest.raises(ScheduleNotFoundError, match="Schedule not found: schedule-1"): + ScheduleService.update_next_run_at(session=session_mock, schedule_id="schedule-1") + + +def test_to_schedule_config_should_build_from_cron_mode() -> None: + # Arrange + node_config: dict[str, Any] = { + "id": "node-1", + "data": { + "mode": "cron", + "cron_expression": "0 12 * * *", + "timezone": "Asia/Kolkata", + }, + } + + # Act + result = ScheduleService.to_schedule_config(node_config=node_config) + + # Assert + assert result.node_id == "node-1" + assert result.cron_expression == "0 12 * * *" + assert result.timezone == "Asia/Kolkata" + + +def test_to_schedule_config_should_raise_for_cron_mode_without_expression() -> None: + # Arrange + node_config = {"id": "node-1", "data": {"mode": "cron", "cron_expression": ""}} + + # Act / Assert + with pytest.raises(ScheduleConfigError, match="Cron expression is required for cron mode"): + ScheduleService.to_schedule_config(node_config=node_config) + + +def test_to_schedule_config_should_build_from_visual_mode(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + node_config = { + "id": "node-1", + "data": { + "mode": "visual", + "frequency": "daily", + "visual_config": {"time": "9:30 AM"}, + "timezone": "UTC", + }, + } + monkeypatch.setattr(ScheduleService, "visual_to_cron", MagicMock(return_value="30 9 * * *")) + + # Act + result = ScheduleService.to_schedule_config(node_config=node_config) + + # Assert + assert result.cron_expression == "30 9 * * *" + + +def test_to_schedule_config_should_raise_for_invalid_mode() -> None: + # Arrange + node_config = {"id": "node-1", "data": {"mode": "manual"}} + + # Act / Assert + with pytest.raises(ScheduleConfigError, match="Invalid 
schedule mode: manual"): + ScheduleService.to_schedule_config(node_config=node_config) + + +def test_extract_schedule_config_should_raise_when_graph_is_empty() -> None: + # Arrange + workflow = _workflow(graph_dict={}) + + # Act / Assert + with pytest.raises(ScheduleConfigError, match="Workflow graph is empty"): + ScheduleService.extract_schedule_config(workflow=workflow) + + +def test_extract_schedule_config_should_raise_when_mode_invalid() -> None: + # Arrange + workflow = _workflow( + graph_dict={ + "nodes": [ + { + "id": "schedule-1", + "data": { + "type": TRIGGER_SCHEDULE_NODE_TYPE, + "mode": "invalid", + }, + } + ] + } + ) + + # Act / Assert + with pytest.raises(ScheduleConfigError, match="Invalid schedule mode: invalid"): + ScheduleService.extract_schedule_config(workflow=workflow) + + if __name__ == "__main__": unittest.main() diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 9c231352256..27602bb1cc2 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -12,6 +12,7 @@ This test suite covers all functionality of the current VariableTruncator includ import functools import json import uuid +from collections.abc import Mapping from typing import Any from uuid import uuid4 @@ -199,14 +200,14 @@ class TestArrayTruncation: def test_small_array_no_truncation(self, small_truncator: VariableTruncator): """Test that small arrays are not truncated.""" - small_array = [1, 2] + small_array: list[object] = [1, 2] result = small_truncator._truncate_array(small_array, 1000) assert result.value == small_array assert result.truncated is False def test_array_element_limit_truncation(self, small_truncator: VariableTruncator): """Test that arrays over element limit are truncated.""" - large_array = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3 + large_array: list[object] = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3 result 
= small_truncator._truncate_array(large_array, 1000) assert result.truncated is True @@ -215,7 +216,7 @@ class TestArrayTruncation: def test_array_size_budget_truncation(self, small_truncator: VariableTruncator): """Test array truncation due to size budget constraints.""" # Create array with strings that will exceed size budget - large_strings = ["very long string " * 5, "another long string " * 5] + large_strings: list[object] = ["very long string " * 5, "another long string " * 5] result = small_truncator._truncate_array(large_strings, 50) assert result.truncated is True @@ -276,10 +277,10 @@ class TestObjectTruncation: # Values should be truncated if they exist for key, value in result.value.items(): - if isinstance(value, str): - original_value = obj_with_long_values[key] - # Value should be same or smaller - assert len(value) <= len(original_value) + assert isinstance(value, str) + original_value = obj_with_long_values[key] + # Value should be same or smaller + assert len(value) <= len(original_value) def test_object_key_dropping(self, small_truncator): """Test object truncation where keys are dropped due to size constraints.""" @@ -506,10 +507,9 @@ class TestEdgeCases: truncator = VariableTruncator(string_length_limit=10) # Unicode characters - unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀" # Each emoji counts as 1 character + unicode_text = "你好世界你好世界你好世界" # Multi-byte UTF-8 characters result = truncator.truncate(StringSegment(value=unicode_text)) - if len(unicode_text) > 10: - assert result.truncated is True + assert result.truncated is True # Special JSON characters special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}' @@ -631,13 +631,12 @@ class TestIntegrationScenarios: result = truncator.truncate(segment) assert isinstance(result, TruncationResult) - # Should handle all data types appropriately - if result.truncated: - # Verify the result is smaller or equal than original - original_size =
truncator.calculate_json_size(mixed_data) - if isinstance(result.result, ObjectSegment): - result_size = truncator.calculate_json_size(result.result.value) - assert result_size <= original_size + assert result.truncated is True + assert isinstance(result.result, ObjectSegment) + # Verify the result is smaller or equal than original + original_size = truncator.calculate_json_size(mixed_data) + result_size = truncator.calculate_json_size(result.result.value) + assert result_size <= original_size def test_file_and_array_file_variable_mapping(self, file): truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300) @@ -675,3 +674,229 @@ def test_dummy_variable_truncator_methods(): assert isinstance(result, TruncationResult) assert result.result == segment assert result.truncated is False + + +# === Merged from test_variable_truncator_additional.py === + + +from typing import Any + +import pytest +from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable +from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment +from graphon.variables.types import SegmentType + +from services import variable_truncator as truncator_module +from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator + + +class _AbstractPassthrough(BaseTruncator): + def truncate(self, segment: Any) -> TruncationResult: + # Arrange / Act + return super().truncate(segment) # type: ignore[misc] + + def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]: + # Arrange / Act + return super().truncate_variable_mapping(v) # type: ignore[misc] + + +def test_base_truncator_methods_should_execute_abstract_placeholders() -> None: + # Arrange + passthrough = _AbstractPassthrough() + + # Act + truncate_result = passthrough.truncate(StringSegment(value="x")) + mapping_result = passthrough.truncate_variable_mapping({"a": 1}) + + # Assert + assert truncate_result is None + 
assert mapping_result is None + + +def test_default_should_use_dify_config_limits(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111) + monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7) + monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33) + + # Act + truncator = VariableTruncator.default() + + # Assert + assert truncator._max_size_bytes == 111 + assert truncator._array_element_limit == 7 + assert truncator._string_length_limit == 33 + + +def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis() -> None: + # Arrange + truncator = VariableTruncator(max_size_bytes=5) + mapping = {"very_long_key": "value"} + + # Act + result, truncated = truncator.truncate_variable_mapping(mapping) + + # Assert + assert result == {"very_long_key": "..."} + assert truncated is True + + +def test_truncate_variable_mapping_should_handle_segment_values() -> None: + # Arrange + truncator = VariableTruncator(max_size_bytes=100) + mapping = {"seg": StringSegment(value="hello")} + + # Act + result, truncated = truncator.truncate_variable_mapping(mapping) + + # Assert + assert isinstance(result["seg"], StringSegment) + assert result["seg"].value == "hello" + assert truncated is False + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (None, False), + (True, False), + (1, False), + (1.5, False), + ("x", True), + ({"k": "v"}, True), + ], +) +def test_json_value_needs_truncation_should_match_expected_rules(value: Any, expected: bool) -> None: + # Arrange + + # Act + result = VariableTruncator._json_value_needs_truncation(value) + + # Assert + assert result is expected + + +def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + truncator = VariableTruncator(max_size_bytes=10) + 
forced_result = truncator_module._PartResult( + value=StringSegment(value="this is too long"), + value_size=100, + truncated=True, + ) + monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result) + + # Act + result = truncator.truncate(StringSegment(value="input")) + + # Assert + assert result.truncated is True + assert isinstance(result.result, StringSegment) + assert not result.result.value.startswith('"') + + +def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + truncator = VariableTruncator() + monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True) + + # Act / Assert + with pytest.raises(AssertionError): + truncator._truncate_segment(IntegerSegment(value=1), 10) + + +def test_calculate_json_size_should_unwrap_segment_values() -> None: + # Arrange + segment = StringSegment(value="abc") + + # Act + size = VariableTruncator.calculate_json_size(segment) + + # Assert + assert size == VariableTruncator.calculate_json_size("abc") + + +def test_calculate_json_size_should_handle_updated_variable_instances() -> None: + # Arrange + updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v") + + # Act + size = VariableTruncator.calculate_json_size(updated) + + # Assert + assert size > 0 + + +def test_maybe_qa_structure_should_validate_shape() -> None: + # Arrange + + # Act / Assert + assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True + assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False + assert VariableTruncator._maybe_qa_structure({}) is False + + +def test_maybe_parent_child_structure_should_validate_shape() -> None: + # Arrange + + # Act / Assert + assert VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True + assert 
VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False + assert ( + VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"}) is False + ) + + +def test_truncate_object_should_truncate_segment_values_inside_object() -> None: + # Arrange + truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30) + mapping = {"s": StringSegment(value="long-content")} + + # Act + result = truncator._truncate_object(mapping, 20) + + # Assert + assert result.truncated is True + assert isinstance(result.value["s"], StringSegment) + + +def test_truncate_json_primitives_should_handle_updated_variable_input() -> None: + # Arrange + truncator = VariableTruncator(max_size_bytes=100) + updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v") + + # Act + result = truncator._truncate_json_primitives(updated, 100) + + # Assert + assert isinstance(result.value, dict) + + +def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type() -> None: + # Arrange + truncator = VariableTruncator() + + # Act / Assert + with pytest.raises(AssertionError): + truncator._truncate_json_primitives(object(), 100) # type: ignore[arg-type] + + +def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + truncator = VariableTruncator(max_size_bytes=10) + forced_segment = ObjectSegment(value={"k": "v"}) + forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True) + monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result) + + # Act + result = truncator.truncate(ObjectSegment(value={"a": "b"})) + + # Assert + assert result.truncated is True + assert isinstance(result.result, StringSegment) diff --git a/api/tests/unit_tests/services/test_webhook_service.py 
b/api/tests/unit_tests/services/test_webhook_service.py index ffdcc046f98..78049182ad7 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -559,3 +559,757 @@ class TestWebhookServiceUnit: result = _prepare_webhook_execution("test_webhook", is_debug=True) assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None) + + +# === Merged from test_webhook_service_additional.py === + + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from flask import Flask +from graphon.variables.types import SegmentType +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import RequestEntityTooLarge + +from core.workflow.nodes.trigger_webhook.entities import ( + ContentType, + WebhookBodyParameter, + WebhookData, + WebhookParameter, +) +from models.enums import AppTriggerStatus +from models.model import App +from models.trigger import WorkflowWebhookTrigger +from models.workflow import Workflow +from services.errors.app import QuotaExceededError +from services.trigger import webhook_service as service_module +from services.trigger.webhook_service import WebhookService + + +class _FakeQuery: + def __init__(self, result: Any) -> None: + self._result = result + + def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery": + return self + + def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery": + return self + + def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery": + return self + + def first(self) -> Any: + return self._result + + +class _SessionContext: + def __init__(self, session: Any) -> None: + self._session = session + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +@pytest.fixture +def flask_app() -> Flask: + return Flask(__name__) + + +def _patch_session(monkeypatch: pytest.MonkeyPatch, 
session: Any) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock())) + monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session)) + + +def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger: + return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs)) + + +def _workflow(**kwargs: Any) -> Workflow: + return cast(Workflow, SimpleNamespace(**kwargs)) + + +def _app(**kwargs: Any) -> App: + return cast(App, SimpleNamespace(**kwargs)) + + +def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + fake_session = MagicMock() + fake_session.query.return_value = _FakeQuery(None) + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="Webhook not found"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + fake_session = MagicMock() + fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="App trigger not found"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED) + fake_session = MagicMock() + fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="rate limited"): + 
WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED) + fake_session = MagicMock() + fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="disabled"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED) + fake_session = MagicMock() + fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)] + _patch_session(monkeypatch, fake_session) + + # Act / Assert + with pytest.raises(ValueError, match="Workflow not found"): + WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + +def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED) + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}} + + fake_session = MagicMock() + fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)] + _patch_session(monkeypatch, fake_session) + + # Act + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1") + + # Assert + assert got_trigger is webhook_trigger + assert got_workflow 
is workflow + assert got_node_config == {"data": {"key": "value"}} + + +def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}} + + fake_session = MagicMock() + fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)] + _patch_session(monkeypatch, fake_session) + + # Act + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow( + "webhook-1", is_debug=True + ) + + # Assert + assert got_trigger is webhook_trigger + assert got_workflow is workflow + assert got_node_config == {"data": {"mode": "debug"}} + + +def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + warning_mock = MagicMock() + monkeypatch.setattr(service_module.logger, "warning", warning_mock) + webhook_trigger = MagicMock() + + # Act + with flask_app.test_request_context( + "/webhook", + method="POST", + headers={"Content-Type": "application/vnd.custom"}, + data="plain content", + ): + result = WebhookService.extract_webhook_data(webhook_trigger) + + # Assert + assert result["body"] == {"raw": "plain content"} + warning_mock.assert_called_once() + + +def test_extract_webhook_data_should_raise_for_request_too_large( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1) + + # Act / Assert + with flask_app.test_request_context("/webhook", method="POST", data="ab"): + with pytest.raises(RequestEntityTooLarge): + WebhookService.extract_webhook_data(MagicMock()) + + +def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None: + # Arrange + 
webhook_trigger = MagicMock() + + # Act + with flask_app.test_request_context("/webhook", method="POST", data=b""): + body, files = WebhookService._extract_octet_stream_body(webhook_trigger) + + # Assert + assert body == {"raw": None} + assert files == {} + + +def test_extract_octet_stream_body_should_return_none_when_processing_raises( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = MagicMock() + monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream")) + monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom"))) + + # Act + with flask_app.test_request_context("/webhook", method="POST", data=b"abc"): + body, files = WebhookService._extract_octet_stream_body(webhook_trigger) + + # Assert + assert body == {"raw": None} + assert files == {} + + +def test_extract_text_body_should_return_empty_string_when_request_read_fails( + flask_app: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error"))) + + # Act + with flask_app.test_request_context("/webhook", method="POST", data="abc"): + body, files = WebhookService._extract_text_body() + + # Assert + assert body == {"raw": ""} + assert files == {} + + +def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + fake_magic = MagicMock() + fake_magic.from_buffer.side_effect = RuntimeError("magic failed") + monkeypatch.setattr(service_module, "magic", fake_magic) + + # Act + result = WebhookService._detect_binary_mimetype(b"binary") + + # Assert + assert result == "application/octet-stream" + + +def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = _workflow_trigger(created_by="user-1", 
tenant_id="tenant-1") + file_obj = MagicMock() + file_obj.to_dict.return_value = {"id": "f-1"} + monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj)) + monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None))) + + uploaded = MagicMock() + uploaded.filename = "file.unknown" + uploaded.content_type = None + uploaded.read.return_value = b"content" + + # Act + result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger) + + # Assert + assert result == {"f": {"id": "f-1"}} + + +def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1") + manager = MagicMock() + manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1") + monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager)) + expected_file = MagicMock() + monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file)) + + # Act + result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger) + + # Assert + assert result is expected_file + manager.create_file_by_raw.assert_called_once() + + +@pytest.mark.parametrize( + ("raw_value", "param_type", "expected"), + [ + ("42", SegmentType.NUMBER, 42), + ("3.14", SegmentType.NUMBER, 3.14), + ("yes", SegmentType.BOOLEAN, True), + ("no", SegmentType.BOOLEAN, False), + ], +) +def test_convert_form_value_should_convert_supported_types( + raw_value: str, + param_type: str, + expected: Any, +) -> None: + # Arrange + + # Act + result = WebhookService._convert_form_value("param", raw_value, param_type) + + # Assert + assert result == expected + + +def test_convert_form_value_should_raise_for_unsupported_type() -> None: + # Arrange + + # Act / Assert + with pytest.raises(ValueError, match="Unsupported 
type"): + WebhookService._convert_form_value("p", "x", SegmentType.FILE) + + +def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + warning_mock = MagicMock() + monkeypatch.setattr(service_module.logger, "warning", warning_mock) + + # Act + result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type") + + # Assert + assert result == {"x": 1} + warning_mock.assert_called_once() + + +def test_validate_and_convert_value_should_wrap_conversion_errors() -> None: + # Arrange + + # Act / Assert + with pytest.raises(ValueError, match="validation failed"): + WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True) + + +def test_process_parameters_should_raise_when_required_parameter_missing() -> None: + # Arrange + raw_params = {"optional": "x"} + config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)] + + # Act / Assert + with pytest.raises(ValueError, match="Required parameter missing"): + WebhookService._process_parameters(raw_params, config, is_form_data=True) + + +def test_process_parameters_should_include_unconfigured_parameters() -> None: + # Arrange + raw_params = {"known": "1", "unknown": "x"} + config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)] + + # Act + result = WebhookService._process_parameters(raw_params, config, is_form_data=True) + + # Assert + assert result == {"known": 1, "unknown": "x"} + + +def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None: + # Arrange + + # Act / Assert + with pytest.raises(ValueError, match="Required body content missing"): + WebhookService._process_body_parameters( + raw_body={"raw": ""}, + body_configs=[WebhookBodyParameter(name="raw", required=True)], + content_type=ContentType.TEXT, + ) + + +def 
test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None: + # Arrange + raw_body = {"message": "hello", "extra": "x"} + body_configs = [ + WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True), + WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True), + ] + + # Act + result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA) + + # Assert + assert result == {"message": "hello", "extra": "x"} + + +def test_validate_required_headers_should_accept_sanitized_header_names() -> None: + # Arrange + headers = {"x_api_key": "123"} + configs = [WebhookParameter(name="x-api-key", required=True)] + + # Act + WebhookService._validate_required_headers(headers, configs) + + # Assert + assert True + + +def test_validate_required_headers_should_raise_when_required_header_missing() -> None: + # Arrange + headers = {"x-other": "123"} + configs = [WebhookParameter(name="x-api-key", required=True)] + + # Act / Assert + with pytest.raises(ValueError, match="Required header missing"): + WebhookService._validate_required_headers(headers, configs) + + +def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None: + # Arrange + webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}} + node_data = WebhookData(method="post", content_type=ContentType.TEXT) + + # Act + result = WebhookService._validate_http_metadata(webhook_data, node_data) + + # Assert + assert result["valid"] is False + assert "Content-type mismatch" in result["error"] + + +def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None: + # Arrange + headers = {"content-type": "application/json; charset=utf-8"} + + # Act + result = WebhookService._extract_content_type(headers) + + # Assert + assert result == "application/json" + + +def test_build_workflow_inputs_should_include_expected_keys() -> None: + # Arrange + webhook_data = {"headers": {"h": "v"}, 
"query_params": {"q": 1}, "body": {"b": 2}} + + # Act + result = WebhookService.build_workflow_inputs(webhook_data) + + # Assert + assert result["webhook_data"] == webhook_data + assert result["webhook_headers"] == {"h": "v"} + assert result["webhook_query_params"] == {"q": 1} + assert result["webhook_body"] == {"b": 2} + + +def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = _workflow_trigger( + app_id="app-1", + node_id="node-1", + tenant_id="tenant-1", + webhook_id="webhook-1", + ) + workflow = _workflow(id="wf-1") + webhook_data = {"body": {"x": 1}} + + session = MagicMock() + _patch_session(monkeypatch, session) + + end_user = SimpleNamespace(id="end-user-1") + monkeypatch.setattr( + service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user) + ) + quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock())) + monkeypatch.setattr(service_module, "QuotaType", quota_type) + trigger_async_mock = MagicMock() + monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock) + + # Act + WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) + + # Assert + trigger_async_mock.assert_called_once() + + +def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + webhook_trigger = _workflow_trigger( + app_id="app-1", + node_id="node-1", + tenant_id="tenant-1", + webhook_id="webhook-1", + ) + workflow = _workflow(id="wf-1") + + session = MagicMock() + _patch_session(monkeypatch, session) + + monkeypatch.setattr( + service_module.EndUserService, + "get_or_create_end_user_by_type", + MagicMock(return_value=SimpleNamespace(id="end-user-1")), + ) + quota_type = SimpleNamespace( + TRIGGER=SimpleNamespace( + 
consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1)) + ) + ) + monkeypatch.setattr(service_module, "QuotaType", quota_type) + mark_rate_limited_mock = MagicMock() + monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock) + + # Act / Assert + with pytest.raises(QuotaExceededError): + WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow) + mark_rate_limited_mock.assert_called_once_with("tenant-1") + + +def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + webhook_trigger = _workflow_trigger( + app_id="app-1", + node_id="node-1", + tenant_id="tenant-1", + webhook_id="webhook-1", + ) + workflow = _workflow(id="wf-1") + + session = MagicMock() + _patch_session(monkeypatch, session) + + monkeypatch.setattr( + service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom")) + ) + logger_exception_mock = MagicMock() + monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock) + + # Act / Assert + with pytest.raises(RuntimeError, match="boom"): + WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow) + logger_exception_mock.assert_called_once() + + +def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow( + walk_nodes=lambda _node_type: [ + (f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1) + ] + ) + + # Act / Assert + with pytest.raises(ValueError, match="maximum webhook node limit"): + WebhookService.sync_webhook_relationships(app, workflow) + + +def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + app = _app(id="app-1", 
tenant_id="tenant-1", created_by="user-1") + workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})]) + + lock = MagicMock() + lock.acquire.return_value = False + monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock)) + + # Act / Assert + with pytest.raises(RuntimeError, match="Failed to acquire lock"): + WebhookService.sync_webhook_relationships(app, workflow) + + +def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})]) + + class _WorkflowWebhookTrigger: + app_id = "app_id" + tenant_id = "tenant_id" + webhook_id = "webhook_id" + node_id = "node_id" + + def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None: + self.id = None + self.app_id = app_id + self.tenant_id = tenant_id + self.node_id = node_id + self.webhook_id = webhook_id + self.created_by = created_by + + class _Select: + def where(self, *args: Any, **kwargs: Any) -> "_Select": + return self + + class _Session: + def __init__(self) -> None: + self.added: list[Any] = [] + self.deleted: list[Any] = [] + self.commit_count = 0 + self.existing_records = [SimpleNamespace(node_id="node-stale")] + + def scalars(self, _stmt: Any) -> Any: + return SimpleNamespace(all=lambda: self.existing_records) + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def flush(self) -> None: + for idx, obj in enumerate(self.added, start=1): + if obj.id is None: + obj.id = f"rec-{idx}" + + def commit(self) -> None: + self.commit_count += 1 + + def delete(self, obj: Any) -> None: + self.deleted.append(obj) + + lock = MagicMock() + lock.acquire.return_value = True + lock.release.return_value = None + + 
fake_session = _Session() + + monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger) + monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select())) + monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock)) + redis_set_mock = MagicMock() + redis_delete_mock = MagicMock() + monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock) + monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock) + monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id")) + _patch_session(monkeypatch, fake_session) + + # Act + WebhookService.sync_webhook_relationships(app, workflow) + + # Assert + assert len(fake_session.added) == 1 + assert len(fake_session.deleted) == 1 + assert fake_session.commit_count == 2 + redis_set_mock.assert_called_once() + redis_delete_mock.assert_called_once() + lock.release.assert_called_once() + + +def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1") + workflow = _workflow(walk_nodes=lambda _node_type: []) + + class _Select: + def where(self, *args: Any, **kwargs: Any) -> "_Select": + return self + + class _Session: + def scalars(self, _stmt: Any) -> Any: + return SimpleNamespace(all=lambda: []) + + def commit(self) -> None: + return None + + lock = MagicMock() + lock.acquire.return_value = True + lock.release.side_effect = RuntimeError("release failed") + + logger_exception_mock = MagicMock() + + monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select())) + monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock)) + 
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock) + _patch_session(monkeypatch, _Session()) + + # Act + WebhookService.sync_webhook_relationships(app, workflow) + + # Assert + assert logger_exception_mock.call_count == 1 + + +def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None: + # Arrange + node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}} + + # Act + body, status = WebhookService.generate_webhook_response(node_config) + + # Assert + assert status == 200 + assert "message" in body + + +def test_generate_webhook_id_should_return_24_character_identifier() -> None: + # Arrange + + # Act + webhook_id = WebhookService.generate_webhook_id() + + # Assert + assert isinstance(webhook_id, str) + assert len(webhook_id) == 24 + + +def test_sanitize_key_should_return_original_value_for_non_string_input() -> None: + # Arrange + + # Act + result = WebhookService._sanitize_key(123) # type: ignore[arg-type] + + # Assert + assert result == 123 diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index a62c9f45556..64b21317abb 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -176,3 +176,300 @@ class TestWorkflowRunService: service = WorkflowRunService(session_factory) assert service._session_factory == session_factory + + +# === Merged from test_workflow_run_service.py === + + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest + +from models import Account, App, EndUser, WorkflowRunTriggeredFrom +from services import workflow_run_service as service_module +from services.workflow_run_service import WorkflowRunService + + +@pytest.fixture +def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, 
Any]: + # Swap the repository factory for a stub that returns inspectable mocks + node_repo = MagicMock() + workflow_run_repo = MagicMock() + factory = SimpleNamespace( + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + + # Hand the mocks to the test so it can configure and verify them + return node_repo, workflow_run_repo, factory + + +def _app_model(**kwargs: Any) -> App: + return cast(App, SimpleNamespace(**kwargs)) + + +def _account(**kwargs: Any) -> Account: + return cast(Account, SimpleNamespace(**kwargs)) + + +def _end_user(**kwargs: Any) -> EndUser: + return cast(EndUser, SimpleNamespace(**kwargs)) + + +def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing( + monkeypatch: pytest.MonkeyPatch, + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], +) -> None: + # Arrange + session_factory = MagicMock(name="session_factory") + sessionmaker_mock = MagicMock(return_value=session_factory) + monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock) + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine")) + + # Act + service = WorkflowRunService() + + # Assert + sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False) + assert service._session_factory is session_factory + + +def test___init___should_create_sessionmaker_when_engine_is_provided( + monkeypatch: pytest.MonkeyPatch, + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], +) -> None: + # Arrange + class FakeEngine: + pass + + session_factory = MagicMock(name="session_factory") + sessionmaker_mock = MagicMock(return_value=session_factory) + monkeypatch.setattr(service_module, "Engine", FakeEngine) + monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock) + engine = cast(Engine, FakeEngine()) # NOTE(review): Engine must be imported above this merged section - confirm + + # Act + service = WorkflowRunService(session_factory=engine) + + # Assert +
sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False) + assert service._session_factory is session_factory + + +def test___init___should_keep_provided_sessionmaker_and_create_repositories( + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], +) -> None: + # Arrange + node_repo, workflow_run_repo, factory = repository_factory_mocks + session_factory = MagicMock(name="session_factory") + + # Act + service = WorkflowRunService(session_factory=session_factory) + + # Assert + assert service._session_factory is session_factory + assert service._node_execution_service_repo is node_repo + assert service._workflow_run_repo is workflow_run_repo + factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory) + factory.create_api_workflow_run_repository.assert_called_once_with(session_factory) + + +def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit( + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], +) -> None: + # Arrange + _, workflow_run_repo, _ = repository_factory_mocks + service = WorkflowRunService(session_factory=MagicMock(name="session_factory")) + app_model = _app_model(tenant_id="tenant-1", id="app-1") + expected = MagicMock(name="pagination") + workflow_run_repo.get_paginated_workflow_runs.return_value = expected + args = {"limit": "7", "last_id": "last-1", "status": "succeeded"} + + # Act + result = service.get_paginate_workflow_runs( + app_model=app_model, + args=args, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Assert + assert result is expected + workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with( + tenant_id="tenant-1", + app_id="app-1", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + limit=7, + last_id="last-1", + status="succeeded", + ) + + +def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists( + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], + monkeypatch: 
pytest.MonkeyPatch, +) -> None: + # Arrange + service = WorkflowRunService(session_factory=MagicMock(name="session_factory")) + app_model = _app_model(tenant_id="tenant-1", id="app-1") + run_with_message = SimpleNamespace( + id="run-1", + status="running", + message=SimpleNamespace(id="msg-1", conversation_id="conv-1"), + ) + run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None) + pagination = SimpleNamespace(data=[run_with_message, run_without_message]) + monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination)) + + # Act + result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"}) + + # Assert + assert result is pagination + assert len(result.data) == 2 + assert result.data[0].message_id == "msg-1" + assert result.data[0].conversation_id == "conv-1" + assert result.data[0].status == "running" + assert not hasattr(result.data[1], "message_id") + assert result.data[1].id == "run-2" + + +def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app( + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], +) -> None: + # Arrange + _, workflow_run_repo, _ = repository_factory_mocks + service = WorkflowRunService(session_factory=MagicMock(name="session_factory")) + app_model = _app_model(tenant_id="tenant-1", id="app-1") + expected = MagicMock(name="workflow_run") + workflow_run_repo.get_workflow_run_by_id.return_value = expected + + # Act + result = service.get_workflow_run(app_model=app_model, run_id="run-1") + + # Assert + assert result is expected + workflow_run_repo.get_workflow_run_by_id.assert_called_once_with( + tenant_id="tenant-1", + app_id="app-1", + run_id="run-1", + ) + + +def test_get_workflow_runs_count_should_forward_optional_filters( + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], +) -> None: + # Arrange + _, workflow_run_repo, _ = repository_factory_mocks + service = 
WorkflowRunService(session_factory=MagicMock(name="session_factory")) + app_model = _app_model(tenant_id="tenant-1", id="app-1") + expected = {"total": 3, "succeeded": 2} + workflow_run_repo.get_workflow_runs_count.return_value = expected + + # Act + result = service.get_workflow_runs_count( + app_model=app_model, + status="succeeded", + time_range="7d", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Assert + assert result == expected + workflow_run_repo.get_workflow_runs_count.assert_called_once_with( + tenant_id="tenant-1", + app_id="app-1", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + status="succeeded", + time_range="7d", + ) + + +def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found( + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + service = WorkflowRunService(session_factory=MagicMock(name="session_factory")) + monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None)) + app_model = _app_model(id="app-1") + user = _account(current_tenant_id="tenant-1") + + # Act + result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user) + + # Assert + assert result == [] + + +def test_get_workflow_run_node_executions_should_use_end_user_tenant_id( + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + node_repo, _, _ = repository_factory_mocks + service = WorkflowRunService(session_factory=MagicMock(name="session_factory")) + monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1"))) + + class FakeEndUser: + def __init__(self, tenant_id: str) -> None: + self.tenant_id = tenant_id + + monkeypatch.setattr(service_module, "EndUser", FakeEndUser) + user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user")) + app_model = _app_model(id="app-1") + expected = 
[SimpleNamespace(id="exec-1")] + node_repo.get_executions_by_workflow_run.return_value = expected + + # Act + result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user) + + # Assert + assert result == expected + node_repo.get_executions_by_workflow_run.assert_called_once_with( + tenant_id="tenant-end-user", + app_id="app-1", + workflow_run_id="run-1", + ) + + +def test_get_workflow_run_node_executions_should_use_account_current_tenant_id( + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + node_repo, _, _ = repository_factory_mocks + service = WorkflowRunService(session_factory=MagicMock(name="session_factory")) + monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1"))) + app_model = _app_model(id="app-1") + user = _account(current_tenant_id="tenant-account") + expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")] + node_repo.get_executions_by_workflow_run.return_value = expected + + # Act + result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user) + + # Assert + assert result == expected + node_repo.get_executions_by_workflow_run.assert_called_once_with( + tenant_id="tenant-account", + app_id="app-1", + workflow_run_id="run-1", + ) + + +def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none( + repository_factory_mocks: tuple[MagicMock, MagicMock, Any], + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + service = WorkflowRunService(session_factory=MagicMock(name="session_factory")) + monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1"))) + app_model = _app_model(id="app-1") + user = _account(current_tenant_id=None) + + # Act / Assert + with pytest.raises(ValueError, match="tenant_id cannot be None"): + service.get_workflow_run_node_executions(app_model=app_model, 
run_id="run-1", user=user) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter_additional.py b/api/tests/unit_tests/services/workflow/test_workflow_converter_additional.py new file mode 100644 index 00000000000..2aaf3bdf1d5 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter_additional.py @@ -0,0 +1,841 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + DatasetEntity, + DatasetRetrieveConfigEntity, + ExternalDataVariableEntity, + ModelConfigEntity, + PromptTemplateEntity, +) +from core.helper import encrypter +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from models.model import Account, App, AppMode, AppModelConfig +from services.workflow import workflow_converter as converter_module +from services.workflow.workflow_converter import WorkflowConverter + +try: + from graphon.enums import BuiltinNodeTypes + from graphon.model_runtime.entities.llm_entities import LLMMode + from graphon.model_runtime.entities.message_entities import PromptMessageRole + from graphon.variables.input_entities import VariableEntity, VariableEntityType +except ModuleNotFoundError: + from dify_graph.enums import BuiltinNodeTypes + from dify_graph.model_runtime.entities.llm_entities import LLMMode + from dify_graph.model_runtime.entities.message_entities import PromptMessageRole + from dify_graph.variables.input_entities import VariableEntity, VariableEntityType + + +@pytest.fixture +def converter() -> WorkflowConverter: + return WorkflowConverter() + + +def _app_model(**kwargs: Any) -> App: + return cast(App, SimpleNamespace(**kwargs)) + + +def
_account(**kwargs: Any) -> Account: + return cast(Account, SimpleNamespace(**kwargs)) + + +def _app_model_config(**kwargs: Any) -> AppModelConfig: + return cast(AppModelConfig, SimpleNamespace(**kwargs)) + + +def _build_start_graph() -> dict[str, Any]: + return { + "nodes": [ + { + "id": "start", + "position": None, + "data": {"type": BuiltinNodeTypes.START, "variables": [{"variable": "name"}, {"variable": "city"}]}, + } + ], + "edges": [], + } + + +def _build_model_config(mode: str | LLMMode) -> ModelConfigEntity: + return ModelConfigEntity(provider="openai", model="gpt-4", mode=mode, parameters={}, stop=[]) + + +@pytest.fixture +def default_variables() -> list[VariableEntity]: + return [ + VariableEntity(variable="text_input", label="text-input", type=VariableEntityType.TEXT_INPUT), + VariableEntity(variable="paragraph", label="paragraph", type=VariableEntityType.PARAGRAPH), + VariableEntity(variable="select", label="select", type=VariableEntityType.SELECT), + ] + + +def test__convert_to_start_node(default_variables: list[VariableEntity]) -> None: + result = WorkflowConverter()._convert_to_start_node(default_variables) + + assert result["id"] == "start" + assert result["data"]["type"] == BuiltinNodeTypes.START + assert result["data"]["variables"][0]["type"] == "text-input" + assert result["data"]["variables"][0]["variable"] == "text_input" + + +def test__convert_to_http_request_node_for_chatbot( + default_variables: list[VariableEntity], + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Chat apps get an HTTP request node (query bound to sys.query) followed by a code node.""" + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.CHAT + + extension = APIBasedExtension( + tenant_id="tenant_id", + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + extension.id = "api_based_extension_id" + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=extension) + # Stub via monkeypatch so encrypter.decrypt_token is restored at teardown and cannot leak into other tests. + monkeypatch.setattr(encrypter, "decrypt_token", MagicMock(return_value="api_key")) + + external_data_variables = [ +
ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={"api_based_extension_id": "api_based_extension_id"}, + ), + ] + + nodes, mapping = workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables, + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == BuiltinNodeTypes.HTTP_REQUEST + assert nodes[1]["data"]["type"] == BuiltinNodeTypes.CODE + body = json.loads(nodes[0]["data"]["body"]["data"]) + assert body["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY + assert body["params"]["query"] == "{{#sys.query#}}" + assert body["params"]["inputs"]["text_input"] == "{{#start.text_input#}}" + assert mapping == {"external_variable": "code_1"} + + +def test__convert_to_http_request_node_for_workflow_app( + default_variables: list[VariableEntity], + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Workflow apps send an empty query in the body of the generated HTTP request node.""" + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.WORKFLOW + + extension = APIBasedExtension( + tenant_id="tenant_id", + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + extension.id = "api_based_extension_id" + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=extension) + # Stub via monkeypatch so encrypter.decrypt_token is restored at teardown and cannot leak into other tests. + monkeypatch.setattr(encrypter, "decrypt_token", MagicMock(return_value="api_key")) + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={"api_based_extension_id": "api_based_extension_id"}, + ), + ] + + nodes, _ = workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables, + ) + + body = json.loads(nodes[0]["data"]["body"]["data"]) + assert body["params"]["query"] == "" + + +def test__convert_to_knowledge_retrieval_node_for_chatbot() -> None: + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1",
"dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), + ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.ADVANCED_CHAT, + dataset_config=dataset_config, + model_config=model_config, + ) + + assert node is not None + assert node["data"]["query_variable_selector"] == ["sys", "query"] + assert node["data"]["multiple_retrieval_config"]["top_k"] == 5 + + +def test__convert_to_knowledge_retrieval_node_for_workflow_app() -> None: + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable="query", + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), + ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.WORKFLOW, + dataset_config=dataset_config, + model_config=model_config, + ) + + assert node is not None + assert node["data"]["query_variable_selector"] == ["start", "query"] + + +def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables: list[VariableEntity]) -> None: + workflow_converter = WorkflowConverter() + graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []} + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[]) + prompt_template 
= PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helper for {{text_input}} and {{paragraph}}", + ) + + node = workflow_converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=model_config, + graph=graph, + prompt_template=prompt_template, + ) + + assert node["data"]["type"] == BuiltinNodeTypes.LLM + assert node["data"]["memory"] is not None + assert node["data"]["prompt_template"][0]["role"] == "user" + assert "{{#start.text_input#}}" in node["data"]["prompt_template"][0]["text"] + + +def test__convert_to_llm_node_for_chatbot_simple_chat_model_with_empty_template( + default_variables: list[VariableEntity], + monkeypatch: pytest.MonkeyPatch, +) -> None: + workflow_converter = WorkflowConverter() + graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []} + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[]) + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="ignored", + ) + monkeypatch.setattr( + converter_module.SimplePromptTransform, + "get_prompt_template", + lambda self, **kwargs: {"prompt_template": PromptTemplateParser(""), "prompt_rules": {}}, + ) + + node = workflow_converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=model_config, + graph=graph, + prompt_template=prompt_template, + ) + + assert node["data"]["prompt_template"] == [] + + +def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables: list[VariableEntity]) -> None: + workflow_converter = WorkflowConverter() + graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []} + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[]) 
+ prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[AdvancedChatMessageEntity(text="Hello {{text_input}}", role=PromptMessageRole.USER)] + ), + ) + + node = workflow_converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=model_config, + graph=graph, + prompt_template=prompt_template, + ) + + assert isinstance(node["data"]["prompt_template"], list) + assert node["data"]["prompt_template"][0]["role"] == PromptMessageRole.USER.value + + +def test__convert_to_llm_node_for_chatbot_advanced_chat_model_without_template( + default_variables: list[VariableEntity], +) -> None: + workflow_converter = WorkflowConverter() + graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []} + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[]) + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=None, + ) + + node = workflow_converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.WORKFLOW, + model_config=model_config, + graph=graph, + prompt_template=prompt_template, + ) + + assert node["data"]["prompt_template"] == [] + assert node["data"]["memory"] is None + + +def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables: list[VariableEntity]) -> None: + workflow_converter = WorkflowConverter() + graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []} + model_config = ModelConfigEntity( + provider="openai", + model="gpt-3.5-turbo-instruct", + mode=LLMMode.COMPLETION.value, + parameters={}, + stop=[], + ) + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + 
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="Hello {{text_input}} and {{#query#}}", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), + ), + ) + + node = workflow_converter._convert_to_llm_node( + original_app_mode=AppMode.COMPLETION, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=model_config, + graph=graph, + prompt_template=prompt_template, + ) + + assert "{{#sys.query#}}" in node["data"]["prompt_template"]["text"] # {{#query#}} should map to sys.query + assert node["data"]["memory"]["role_prefix"]["user"] == "Human" + + +def test__convert_to_end_node() -> None: + node = WorkflowConverter()._convert_to_end_node() + assert node["id"] == "end" + assert node["data"]["type"] == BuiltinNodeTypes.END + + +def test__convert_to_answer_node() -> None: + node = WorkflowConverter()._convert_to_answer_node() + assert node["id"] == "answer" + assert node["data"]["type"] == BuiltinNodeTypes.ANSWER + + +def test_convert_to_workflow_should_raise_when_app_model_config_is_missing(converter: WorkflowConverter) -> None: + app_model = _app_model(app_model_config=None) + + with pytest.raises(ValueError, match="App model config is required"): + converter.convert_to_workflow( + app_model=app_model, + account=_account(id="account-1"), + name="new-app", + icon_type="emoji", + icon="robot", + icon_background="#fff", + ) + + +@pytest.mark.parametrize( + ("source_mode", "expected_mode"), + [ + (AppMode.CHAT, AppMode.ADVANCED_CHAT), + (AppMode.COMPLETION, AppMode.WORKFLOW), + ], +) +def test_convert_to_workflow_should_create_new_app_with_fallback_fields( + converter: WorkflowConverter, + monkeypatch: pytest.MonkeyPatch, + source_mode: AppMode, + expected_mode: AppMode, +) -> None: + class FakeApp: + def __init__(self) -> None: + self.id = "new-app-id" + + workflow = SimpleNamespace(app_id=None) + monkeypatch.setattr(converter, "convert_app_model_config_to_workflow", MagicMock(return_value=workflow)) +
monkeypatch.setattr(converter_module, "App", FakeApp) + + db_session = SimpleNamespace(add=MagicMock(), flush=MagicMock(), commit=MagicMock()) + monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session)) + + send_mock = MagicMock() + monkeypatch.setattr(converter_module.app_was_created, "send", send_mock) + + account = _account(id="account-1") + app_model = _app_model( + tenant_id="tenant-1", + name="Source App", + mode=source_mode, + icon_type="emoji", + icon="sparkles", + icon_background="#123456", + enable_site=True, + enable_api=True, + api_rpm=10, + api_rph=100, + is_public=False, + app_model_config=_app_model_config(id="config-1"), + ) + + new_app = converter.convert_to_workflow( + app_model=app_model, + account=account, + name="", + icon_type="", + icon="", + icon_background="", + ) + + assert new_app.name == "Source App(workflow)" + assert new_app.mode == expected_mode + assert new_app.icon_type == "emoji" + assert new_app.icon == "sparkles" + assert new_app.icon_background == "#123456" + assert new_app.created_by == "account-1" + assert workflow.app_id == "new-app-id" + db_session.add.assert_called_once() + db_session.flush.assert_called_once() + db_session.commit.assert_called_once() + send_mock.assert_called_once_with(new_app, account=account) + + +def test_convert_app_model_config_to_workflow_should_build_advanced_chat_graph_and_features( + converter: WorkflowConverter, + monkeypatch: pytest.MonkeyPatch, +) -> None: + app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT) + app_config = SimpleNamespace( + variables=[SimpleNamespace(variable="name")], + external_data_variables=[SimpleNamespace(variable="ext")], + dataset=SimpleNamespace(id="dataset"), + model=SimpleNamespace(), + prompt_template=SimpleNamespace(), + additional_features=SimpleNamespace(file_upload=SimpleNamespace()), + app_model_config_dict={ + "opening_statement": "hello", + "suggested_questions": ["q1"], + "suggested_questions_after_answer": 
True, + "speech_to_text": True, + "text_to_speech": {"enabled": True}, + "file_upload": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": True}, + "retriever_resource": {"enabled": True}, + }, + ) + + class FakeWorkflow: + VERSION_DRAFT = "draft" + + def __init__(self, **kwargs: Any) -> None: + self.__dict__.update(kwargs) + + monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.ADVANCED_CHAT)) + monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config)) + monkeypatch.setattr( + converter, + "_convert_to_start_node", + MagicMock( + return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}} + ), + ) + monkeypatch.setattr( + converter, + "_convert_to_http_request_node", + MagicMock( + return_value=( + [{"id": "http", "position": None, "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}], + {"ext": "code_1"}, + ) + ), + ) + monkeypatch.setattr( + converter, + "_convert_to_knowledge_retrieval_node", + MagicMock( + return_value={"id": "knowledge", "position": None, "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}} + ), + ) + monkeypatch.setattr( + converter, + "_convert_to_llm_node", + MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}), + ) + monkeypatch.setattr( + converter, + "_convert_to_answer_node", + MagicMock(return_value={"id": "answer", "position": None, "data": {"type": BuiltinNodeTypes.ANSWER}}), + ) + monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow) + + db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock()) + monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session)) + + workflow = converter.convert_app_model_config_to_workflow( + app_model=app_model, + app_model_config=_app_model_config(id="cfg"), + account_id="account-1", + ) + + graph = json.loads(workflow.graph) + node_ids = [node["id"] for node in graph["nodes"]] + assert node_ids == 
["start", "http", "knowledge", "llm", "answer"] + + features = json.loads(workflow.features) + assert "opening_statement" in features + assert "retriever_resource" in features + db_session.add.assert_called_once() + db_session.commit.assert_called_once() + + +def test_convert_app_model_config_to_workflow_should_build_workflow_mode_with_end_node( + converter: WorkflowConverter, + monkeypatch: pytest.MonkeyPatch, +) -> None: + app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.COMPLETION) + app_config = SimpleNamespace( + variables=[SimpleNamespace(variable="name")], + external_data_variables=[], + dataset=SimpleNamespace(id="dataset"), + model=SimpleNamespace(), + prompt_template=SimpleNamespace(), + additional_features=None, + app_model_config_dict={ + "text_to_speech": {"enabled": False}, + "file_upload": {"enabled": False}, + "sensitive_word_avoidance": {"enabled": False}, + }, + ) + + class FakeWorkflow: + VERSION_DRAFT = "draft" + + def __init__(self, **kwargs: Any) -> None: + self.__dict__.update(kwargs) + + monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.WORKFLOW)) + monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config)) + monkeypatch.setattr( + converter, + "_convert_to_start_node", + MagicMock( + return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}} + ), + ) + monkeypatch.setattr(converter, "_convert_to_knowledge_retrieval_node", MagicMock(return_value=None)) + monkeypatch.setattr( + converter, + "_convert_to_llm_node", + MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}), + ) + monkeypatch.setattr( + converter, + "_convert_to_end_node", + MagicMock(return_value={"id": "end", "position": None, "data": {"type": BuiltinNodeTypes.END}}), + ) + monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow) + + db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock()) 
+ monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session)) + + workflow = converter.convert_app_model_config_to_workflow( + app_model=app_model, + app_model_config=_app_model_config(id="cfg"), + account_id="account-1", + ) + + graph = json.loads(workflow.graph) + node_ids = [node["id"] for node in graph["nodes"]] + assert node_ids == ["start", "llm", "end"] + + features = json.loads(workflow.features) + assert set(features.keys()) == {"text_to_speech", "file_upload", "sensitive_word_avoidance"} + + +def test_convert_to_app_config_should_route_to_correct_manager( + converter: WorkflowConverter, + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent_result = SimpleNamespace(kind="agent") + chat_result = SimpleNamespace(kind="chat") + completion_result = SimpleNamespace(kind="completion") + monkeypatch.setattr( + converter_module.AgentChatAppConfigManager, "get_app_config", MagicMock(return_value=agent_result) + ) + monkeypatch.setattr(converter_module.ChatAppConfigManager, "get_app_config", MagicMock(return_value=chat_result)) + monkeypatch.setattr( + converter_module.CompletionAppConfigManager, + "get_app_config", + MagicMock(return_value=completion_result), + ) + + from_agent_mode = converter._convert_to_app_config( + app_model=_app_model(mode=AppMode.AGENT_CHAT, is_agent=False), + app_model_config=_app_model_config(id="cfg-1"), + ) + from_agent_flag = converter._convert_to_app_config( + app_model=_app_model(mode=AppMode.CHAT, is_agent=True), + app_model_config=_app_model_config(id="cfg-2"), + ) + from_chat_mode = converter._convert_to_app_config( + app_model=_app_model(mode=AppMode.CHAT, is_agent=False), + app_model_config=_app_model_config(id="cfg-3"), + ) + from_completion_mode = converter._convert_to_app_config( + app_model=_app_model(mode=AppMode.COMPLETION, is_agent=False), + app_model_config=_app_model_config(id="cfg-4"), + ) + + assert from_agent_mode is agent_result + assert from_agent_flag is agent_result + assert from_chat_mode is 
chat_result + assert from_completion_mode is completion_result + + +def test_convert_to_app_config_should_raise_for_invalid_app_mode(converter: WorkflowConverter) -> None: + app_model = _app_model(mode=AppMode.WORKFLOW, is_agent=False) + + with pytest.raises(ValueError, match="Invalid app mode"): + converter._convert_to_app_config(app_model=app_model, app_model_config=_app_model_config(id="cfg")) + + +def test_convert_to_http_request_node_should_skip_non_api_and_missing_extension_id( + converter: WorkflowConverter, +) -> None: + app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT) + external_data_variables = [ + ExternalDataVariableEntity(variable="skip_type", type="dataset", config={"api_based_extension_id": "x"}), + ExternalDataVariableEntity(variable="skip_config", type="api", config={}), + ] + + nodes, mapping = converter._convert_to_http_request_node( + app_model=app_model, + variables=[], + external_data_variables=external_data_variables, + ) + + assert nodes == [] + assert mapping == {} + + +def test_convert_to_knowledge_retrieval_node_should_return_none_for_workflow_without_query_variable( + converter: WorkflowConverter, +) -> None: + dataset_config = DatasetEntity( + dataset_ids=["ds-1"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=None, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + ), + ) + model_config = _build_model_config(mode=LLMMode.CHAT) + + node = converter._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.WORKFLOW, + dataset_config=dataset_config, + model_config=model_config, + ) + + assert node is None + + +def test_convert_to_llm_node_should_raise_when_simple_chat_template_missing( + converter: WorkflowConverter, +) -> None: + graph = _build_start_graph() + model_config = _build_model_config(mode=LLMMode.CHAT) + prompt_template = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE) + + with pytest.raises(ValueError, match="Simple prompt 
template is required"): + converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + graph=graph, + model_config=model_config, + prompt_template=prompt_template, + ) + + +def test_convert_to_llm_node_should_raise_when_prompt_template_parser_type_is_invalid_for_chat( + converter: WorkflowConverter, + monkeypatch: pytest.MonkeyPatch, +) -> None: + graph = _build_start_graph() + model_config = _build_model_config(mode=LLMMode.CHAT) + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="Hello {{name}}", + ) + monkeypatch.setattr( + converter_module.SimplePromptTransform, + "get_prompt_template", + lambda self, **kwargs: {"prompt_template": "invalid"}, + ) + + with pytest.raises(TypeError, match="Expected PromptTemplateParser"): + converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + graph=graph, + model_config=model_config, + prompt_template=prompt_template, + ) + + +def test_convert_to_llm_node_should_raise_when_simple_completion_template_missing( + converter: WorkflowConverter, +) -> None: + graph = _build_start_graph() + model_config = _build_model_config(mode=LLMMode.COMPLETION) + prompt_template = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE) + + with pytest.raises(ValueError, match="Simple prompt template is required"): + converter._convert_to_llm_node( + original_app_mode=AppMode.COMPLETION, + new_app_mode=AppMode.WORKFLOW, + graph=graph, + model_config=model_config, + prompt_template=prompt_template, + ) + + +def test_convert_to_llm_node_should_raise_when_completion_prompt_rules_type_is_invalid( + converter: WorkflowConverter, + monkeypatch: pytest.MonkeyPatch, +) -> None: + graph = _build_start_graph() + model_config = _build_model_config(mode=LLMMode.COMPLETION) + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + 
simple_prompt_template="Hello {{name}}", + ) + monkeypatch.setattr( + converter_module.SimplePromptTransform, + "get_prompt_template", + lambda self, **kwargs: {"prompt_template": PromptTemplateParser("Hello {{name}}"), "prompt_rules": "invalid"}, + ) + + with pytest.raises(TypeError, match="Expected dict for prompt_rules"): + converter._convert_to_llm_node( + original_app_mode=AppMode.COMPLETION, + new_app_mode=AppMode.ADVANCED_CHAT, + graph=graph, + model_config=model_config, + prompt_template=prompt_template, + ) + + +def test_convert_to_llm_node_should_use_empty_text_for_advanced_completion_without_template( + converter: WorkflowConverter, +) -> None: + graph = _build_start_graph() + model_config = _build_model_config(mode=LLMMode.COMPLETION) + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=None, + ) + + llm_node = converter._convert_to_llm_node( + original_app_mode=AppMode.COMPLETION, + new_app_mode=AppMode.WORKFLOW, + graph=graph, + model_config=model_config, + prompt_template=prompt_template, + ) + + assert llm_node["data"]["prompt_template"]["text"] == "" + assert llm_node["data"]["memory"] is None + + +def test_replace_template_variables_should_replace_start_and_external_references(converter: WorkflowConverter) -> None: + template = "Hello {{name}} from {{city}} with {{weather}}" + variables = [{"variable": "name"}, {"variable": "city"}] + external_mapping = {"weather": "code_1"} + + result = converter._replace_template_variables(template, variables, external_mapping) + + assert result == "Hello {{#start.name#}} from {{#start.city#}} with {{#code_1.result#}}" + + +def test_graph_helpers_should_create_edges_append_nodes_and_choose_mode(converter: WorkflowConverter) -> None: + graph = {"nodes": [{"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START}}], "edges": []} + node = {"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}} + + edge 
= converter._create_edge("start", "llm") + updated_graph = converter._append_node(graph, node) + workflow_mode = converter._get_new_app_mode(_app_model(mode=AppMode.COMPLETION)) + advanced_chat_mode = converter._get_new_app_mode(_app_model(mode=AppMode.CHAT)) + + assert edge == {"id": "start-llm", "source": "start", "target": "llm"} + assert updated_graph["nodes"][-1]["id"] == "llm" + assert updated_graph["edges"][-1]["source"] == "start" + assert workflow_mode == AppMode.WORKFLOW + assert advanced_chat_mode == AppMode.ADVANCED_CHAT + + +def test_get_api_based_extension_should_raise_when_extension_not_found( + converter: WorkflowConverter, + monkeypatch: pytest.MonkeyPatch, +) -> None: + db_session = SimpleNamespace(scalar=MagicMock(return_value=None)) + monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session)) + + with pytest.raises(ValueError, match="API Based Extension not found"): + converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1") + db_session.scalar.assert_called_once() + + +def test_get_api_based_extension_should_return_entity_when_found( + converter: WorkflowConverter, + monkeypatch: pytest.MonkeyPatch, +) -> None: + extension = SimpleNamespace(id="ext-1") + db_session = SimpleNamespace(scalar=MagicMock(return_value=extension)) + monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session)) + + result = converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1") + + assert result is extension + db_session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index 077a7c27a2b..b8b073f75c0 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -1,10 +1,9 @@ -from __future__ import 
annotations - import json import queue from collections.abc import Sequence from dataclasses import dataclass from datetime import UTC, datetime +from itertools import cycle from threading import Event import pytest @@ -224,3 +223,577 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) - buffer_state.task_id_ready.set() task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0) assert task_id == expected + + +# === Merged from test_workflow_event_snapshot_service_additional.py === + + +import json +import queue +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import UTC, datetime +from threading import Event +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from sqlalchemy.orm import Session, sessionmaker + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.task_entities import StreamEvent +from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper +from models.enums import CreatorUserRole +from models.model import AppMode +from models.workflow import WorkflowRun +from repositories.entities.workflow_pause import WorkflowPauseEntity +from services import workflow_event_snapshot_service as service_module +from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream + + +def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun: + return WorkflowRun( + id="run-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + type="workflow", + triggered_from="app-run", + version="v1", + graph=None, + 
inputs=json.dumps({"query": "hello"}), + status=status, + outputs=json.dumps({}), + error=None, + elapsed_time=1.2, + total_tokens=5, + total_steps=2, + created_by_role=CreatorUserRole.END_USER, + created_by="user-1", + created_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + + +def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext: + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant-1", + app_id="app-1", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-1", + ) + generate_entity = WorkflowAppGenerateEntity( + task_id=task_id, + app_config=app_config, + inputs={}, + files=[], + user_id="user-1", + stream=True, + invoke_from=InvokeFrom.EXPLORE, + call_depth=0, + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + runtime_state.outputs = {"answer": "ok"} + wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity) + return WorkflowResumptionContext( + generate_entity=wrapper, + serialized_graph_runtime_state=runtime_state.dumps(), + ) + + +class _SessionContext: + def __init__(self, session: Any) -> None: + self._session = session + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +class _SessionMaker: + def __init__(self, session: Any) -> None: + self._session = session + + def __call__(self) -> _SessionContext: + return _SessionContext(self._session) + + +class _SubscriptionContext: + def __init__(self, subscription: Any) -> None: + self._subscription = subscription + + def __enter__(self) -> Any: + return self._subscription + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +class _Topic: + def __init__(self, subscription: Any) -> None: + self._subscription = subscription + + def subscribe(self) -> _SubscriptionContext: + return _SubscriptionContext(self._subscription) + + +class _StaticSubscription: + def receive(self, timeout: int = 1) 
-> None: + return None + + +@dataclass(frozen=True) +class _PauseEntity(WorkflowPauseEntity): + state: bytes + + @property + def id(self) -> str: + return "pause-1" + + @property + def workflow_execution_id(self) -> str: + return "run-1" + + @property + def resumed_at(self) -> datetime | None: + return None + + @property + def paused_at(self) -> datetime: + return datetime(2024, 1, 1, tzinfo=UTC) + + def get_state(self) -> bytes: + return self.state + + def get_pause_reasons(self) -> list[Any]: + return [] + + +def test_get_message_context_should_return_none_when_no_message() -> None: + # Arrange + session = SimpleNamespace(scalar=MagicMock(return_value=None)) + session_maker = _SessionMaker(session) + + # Act + result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1") + + # Assert + assert result is None + + +def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None: + # Arrange + message = SimpleNamespace( + id="msg-1", + conversation_id="conv-1", + created_at=None, + answer="answer", + ) + session = SimpleNamespace(scalar=MagicMock(return_value=message)) + session_maker = _SessionMaker(session) + + # Act + result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1") + + # Assert + assert result is not None + assert result.created_at == 0 + assert result.message_id == "msg-1" + assert result.conversation_id == "conv-1" + assert result.answer == "answer" + + +def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None: + # Arrange + + # Act + result = service_module._load_resumption_context(None) + + # Assert + assert result is None + + +def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None: + # Arrange + pause_entity = _PauseEntity(state=b"not-a-valid-state") + + # Act + result = service_module._load_resumption_context(pause_entity) + + # Assert + assert result is None + + 
+def test_load_resumption_context_should_parse_valid_state_into_context() -> None: + # Arrange + context = _build_resumption_context_additional(task_id="task-ctx") + pause_entity = _PauseEntity(state=context.dumps().encode()) + + # Act + result = service_module._load_resumption_context(pause_entity) + + # Assert + assert result is not None + assert result.get_generate_entity().task_id == "task-ctx" + + +def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None: + # Arrange + + # Act + result = service_module._resolve_task_id( + resumption_context=None, + buffer_state=None, + workflow_run_id="run-1", + ) + + # Assert + assert result == "run-1" + + +@pytest.mark.parametrize( + ("payload", "expected"), + [ + (b'{"event":"node_started"}', {"event": "node_started"}), + (b"invalid-json", None), + (b"[]", None), + ], +) +def test_parse_event_message_should_parse_only_json_object( + payload: bytes, + expected: dict[str, Any] | None, +) -> None: + # Arrange + + # Act + result = service_module._parse_event_message(payload) + + # Assert + assert result == expected + + +def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None: + # Arrange + finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value} + paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value} + + # Act + is_finished = service_module._is_terminal_event(finished_event, include_paused=False) + paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False) + paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True) + + # Assert + assert is_finished is True + assert paused_without_flag is False + assert paused_with_flag is True + assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False + + +def test_apply_message_context_should_update_payload_when_context_exists() -> None: + # Arrange + payload: dict[str, Any] = {"event": "workflow_started"} + 
context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000) + + # Act + service_module._apply_message_context(payload, context) + + # Assert + assert payload["conversation_id"] == "conv-1" + assert payload["message_id"] == "msg-1" + assert payload["created_at"] == 1700000000 + + +def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None: + # Arrange + class Subscription: + def __init__(self) -> None: + self._calls = 0 + + def receive(self, timeout: int = 1) -> bytes | None: + self._calls += 1 + if self._calls == 1: + return b'{"event":"node_started","task_id":"task-1"}' + return None + + subscription = Subscription() + + # Act + buffer_state = service_module._start_buffering(subscription) + ready = buffer_state.task_id_ready.wait(timeout=1) + event = buffer_state.queue.get(timeout=1) + buffer_state.stop_event.set() + finished = buffer_state.done_event.wait(timeout=1) + + # Assert + assert ready is True + assert finished is True + assert buffer_state.task_id_hint == "task-1" + assert event["event"] == "node_started" + + +def test_start_buffering_should_drop_old_event_when_queue_is_full( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + class QueueWithSingleFull: + def __init__(self) -> None: + self._first_put = True + self.items: list[dict[str, Any]] = [{"event": "old"}] + + def put_nowait(self, item: dict[str, Any]) -> None: + if self._first_put: + self._first_put = False + raise queue.Full + self.items.append(item) + + def get_nowait(self) -> dict[str, Any]: + if not self.items: + raise queue.Empty + return self.items.pop(0) + + def empty(self) -> bool: + return len(self.items) == 0 + + fake_queue = QueueWithSingleFull() + monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue) + + class Subscription: + def __init__(self) -> None: + self._calls = 0 + + def receive(self, timeout: int = 1) -> bytes | None: + self._calls += 1 + if self._calls == 1: + return 
b'{"event":"node_started","task_id":"task-2"}' + return None + + subscription = Subscription() + + # Act + buffer_state = service_module._start_buffering(subscription) + ready = buffer_state.task_id_ready.wait(timeout=1) + buffer_state.stop_event.set() + finished = buffer_state.done_event.wait(timeout=1) + + # Assert + assert ready is True + assert finished is True + assert fake_queue.items[-1]["task_id"] == "task-2" + + +def test_start_buffering_should_set_done_event_when_subscription_raises() -> None: + # Arrange + class Subscription: + def receive(self, timeout: int = 1) -> bytes | None: + raise RuntimeError("subscription failure") + + subscription = Subscription() + + # Act + buffer_state = service_module._start_buffering(subscription) + finished = buffer_state.done_event.wait(timeout=1) + + # Assert + assert finished is True + + +def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING) + topic = _Topic(_StaticSubscription()) + workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock()) + node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[])) + factory = SimpleNamespace( + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic)) + monkeypatch.setattr( + service_module, + "_get_message_context", + MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)), + ) + monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None)) + buffer_state = BufferState( + queue=queue.Queue(), + stop_event=Event(), + 
done_event=Event(), + task_id_ready=Event(), + task_id_hint="task-1", + ) + monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state)) + monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1")) + monkeypatch.setattr( + service_module, + "_build_snapshot_events", + MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]), + ) + + # Act + events = list( + build_workflow_event_stream( + app_mode=AppMode.ADVANCED_CHAT, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=MagicMock(), + ) + ) + + # Assert + assert events[0] == StreamEvent.PING.value + finished_event = cast(Mapping[str, Any], events[1]) + assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value + assert buffer_state.stop_event.is_set() is True + node_repo.get_execution_snapshots_by_workflow_run.assert_called_once() + called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs + assert called_kwargs["workflow_run_id"] == "run-1" + + +def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING) + topic = _Topic(_StaticSubscription()) + workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock()) + node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[])) + factory = SimpleNamespace( + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic)) + monkeypatch.setattr(service_module, "_load_resumption_context", 
MagicMock(return_value=None)) + monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[])) + monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1")) + + class AlwaysEmptyQueue: + def empty(self) -> bool: + return False + + def get(self, timeout: int = 1) -> None: + raise queue.Empty + + buffer_state = BufferState( + queue=AlwaysEmptyQueue(), # type: ignore[arg-type] + stop_event=Event(), + done_event=Event(), + task_id_ready=Event(), + task_id_hint="task-1", + ) + monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state)) + time_values = cycle([0.0, 6.0, 21.0, 26.0]) + monkeypatch.setattr(service_module.time, "time", lambda: next(time_values)) + + # Act + events = list( + build_workflow_event_stream( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=MagicMock(), + idle_timeout=20.0, + ping_interval=5.0, + ) + ) + + # Assert + assert events == [StreamEvent.PING.value, StreamEvent.PING.value] + assert buffer_state.stop_event.is_set() is True + + +def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING) + topic = _Topic(_StaticSubscription()) + workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock()) + node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[])) + factory = SimpleNamespace( + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic)) + monkeypatch.setattr(service_module, 
"_load_resumption_context", MagicMock(return_value=None)) + monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[])) + monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1")) + buffer_state = BufferState( + queue=queue.Queue(), + stop_event=Event(), + done_event=Event(), + task_id_ready=Event(), + task_id_hint="task-1", + ) + buffer_state.done_event.set() + monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state)) + + # Act + events = list( + build_workflow_event_stream( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=MagicMock(), + ) + ) + + # Assert + assert events == [StreamEvent.PING.value] + assert buffer_state.stop_event.is_set() is True + + +def test_build_workflow_event_stream_should_continue_when_pause_loading_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED) + topic = _Topic(_StaticSubscription()) + workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom"))) + node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[])) + factory = SimpleNamespace( + create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo), + create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo), + ) + monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory) + monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic)) + monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None)) + monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1")) + snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}]) + monkeypatch.setattr(service_module, 
"_build_snapshot_events", snapshot_builder) + buffer_state = BufferState( + queue=queue.Queue(), + stop_event=Event(), + done_event=Event(), + task_id_ready=Event(), + task_id_hint="task-1", + ) + monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state)) + + # Act + events = list( + build_workflow_event_stream( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=MagicMock(), + ) + ) + + # Assert + assert events[0] == StreamEvent.PING.value + assert snapshot_builder.call_args.kwargs["pause_entity"] is None diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 0b189ebae29..34e474c9218 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -10,6 +10,8 @@ This module tests the document indexing task functionality including: """ import uuid +from contextlib import nullcontext +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest @@ -1113,13 +1115,17 @@ class TestAdvancedScenarios: _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) # Assert - # Verify delete was called to clean up task key - mock_redis.delete.assert_called_once() + expected_task_key = f"tenant_document_indexing_task:{tenant_id}" - # Verify the correct key was deleted (contains tenant_id and "document_indexing") - delete_call_args = mock_redis.delete.call_args[0][0] - assert tenant_id in delete_call_args - assert "document_indexing" in delete_call_args + # Verify the task key for this tenant was deleted (do not assert call count; fixtures may be shared). 
+ mock_redis.delete.assert_any_call(expected_task_key) + + deleted_keys = [delete_call.args[0] for delete_call in mock_redis.delete.call_args_list if delete_call.args] + assert expected_task_key in deleted_keys + + deleted_task_key = next(key for key in deleted_keys if key == expected_task_key) + assert tenant_id in deleted_task_key + assert "document_indexing" in deleted_task_key def test_billing_disabled_skips_limit_checks( self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner, mock_feature_service @@ -1510,3 +1516,475 @@ class TestRobustness: # Verify the exception message assert "Feature service" in str(exc_info.value) or isinstance(exc_info.value, Exception) + + +class _SessionContext: + def __init__(self, session: MagicMock) -> None: + self._session = session + + def __enter__(self) -> MagicMock: + return self._session + + def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[override] + return None + + +class TestDocumentIndexingTaskSummaryFlow: + """Additional coverage for summary and tenant queue branches.""" + + def test_should_return_when_dataset_missing(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test early return when dataset does not exist.""" + # Arrange + session = MagicMock() + dataset_query = MagicMock() + dataset_query.where.return_value = dataset_query + dataset_query.first.return_value = None + session.query.side_effect = lambda model: dataset_query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock) + features_mock = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task.FeatureService.get_features", features_mock) + + # Act + _document_indexing("dataset-1", ["doc-1"]) + + # Assert + features_mock.assert_not_called() + + def test_should_mark_documents_error_when_batch_upload_limit_exceeded( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Test batch 
upload limit triggers error handling.""" + # Arrange + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None) + + dataset_query = MagicMock() + dataset_query.where.return_value = dataset_query + dataset_query.first.return_value = dataset + + document_query = MagicMock() + document_query.where.return_value = document_query + document_query.first.return_value = document + + session = MagicMock() + session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query + + monkeypatch.setattr( + "tasks.document_indexing_task.session_factory.create_session", + MagicMock(return_value=_SessionContext(session)), + ) + + features = SimpleNamespace( + billing=SimpleNamespace( + enabled=True, + subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL), + ), + vector_space=SimpleNamespace(limit=0, size=0), + ) + monkeypatch.setattr( + "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features) + ) + monkeypatch.setattr("tasks.document_indexing_task.dify_config.BATCH_UPLOAD_LIMIT", "1") + + # Act + _document_indexing("dataset-1", ["doc-1", "doc-2"]) + + # Assert + assert document.indexing_status == "error" + assert "batch upload limit" in document.error + session.commit.assert_called_once() + + def test_should_queue_summary_generation_for_completed_documents(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test summary generation is queued for eligible documents.""" + # Arrange + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + + doc_eligible = SimpleNamespace( + id="doc-1", + indexing_status="completed", + doc_form="text", + need_summary=True, + ) + doc_skip_form = SimpleNamespace( + id="doc-2", + indexing_status="completed", + doc_form="qa_model", + need_summary=True, + ) + doc_skip_status = SimpleNamespace( + 
id="doc-3", + indexing_status="processing", + doc_form="text", + need_summary=True, + ) + + dataset_query = MagicMock() + dataset_query.where.return_value = dataset_query + dataset_query.first.return_value = dataset + + phase1_docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2"), SimpleNamespace(id="doc-3")] + phase1_document_query = MagicMock() + phase1_document_query.where.return_value = phase1_document_query + phase1_document_query.all.return_value = phase1_docs + + summary_document_query = MagicMock() + summary_document_query.where.return_value = summary_document_query + summary_document_query.all.return_value = [doc_eligible, doc_skip_form, doc_skip_status] + + session1 = MagicMock() + session2 = MagicMock() + session2.begin.return_value = nullcontext() + session3 = MagicMock() + + session1.query.side_effect = lambda model: dataset_query + session2.query.side_effect = lambda model: phase1_document_query + session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query + + create_session_mock = MagicMock( + side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)] + ) + monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock) + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False), + vector_space=SimpleNamespace(limit=0, size=0), + ) + monkeypatch.setattr( + "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features) + ) + + indexing_runner = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner)) + delay_mock = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock) + + # Act + _document_indexing("dataset-1", ["doc-1", "doc-2", "doc-3"]) + + # Assert + delay_mock.assert_called_once_with("dataset-1", "doc-1", None) + + def 
test_should_continue_when_summary_queue_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test summary queueing errors are swallowed.""" + # Arrange + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + + doc_eligible = SimpleNamespace( + id="doc-1", + indexing_status="completed", + doc_form="text", + need_summary=True, + ) + + dataset_query = MagicMock() + dataset_query.where.return_value = dataset_query + dataset_query.first.return_value = dataset + + phase1_query = MagicMock() + phase1_query.where.return_value = phase1_query + phase1_query.all.return_value = [SimpleNamespace(id="doc-1")] + + summary_query = MagicMock() + summary_query.where.return_value = summary_query + summary_query.all.return_value = [doc_eligible] + + session1 = MagicMock() + session2 = MagicMock() + session2.begin.return_value = nullcontext() + session3 = MagicMock() + session1.query.side_effect = lambda model: dataset_query + session2.query.side_effect = lambda model: phase1_query + session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query + + monkeypatch.setattr( + "tasks.document_indexing_task.session_factory.create_session", + MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]), + ) + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False), + vector_space=SimpleNamespace(limit=0, size=0), + ) + monkeypatch.setattr( + "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features) + ) + + indexing_runner = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner)) + delay_mock = MagicMock(side_effect=Exception("boom")) + monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock) + + # Act + _document_indexing("dataset-1", ["doc-1"]) + + # Assert 
+ delay_mock.assert_called_once_with("dataset-1", "doc-1", None) + + def test_should_return_when_dataset_missing_after_indexing(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test early return when dataset is missing after indexing.""" + # Arrange + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + dataset_query = MagicMock() + dataset_query.where.return_value = dataset_query + dataset_query.first.side_effect = [dataset, None] + + document_query = MagicMock() + document_query.where.return_value = document_query + document_query.all.return_value = [SimpleNamespace(id="doc-1")] + + session1 = MagicMock() + session2 = MagicMock() + session2.begin.return_value = nullcontext() + session3 = MagicMock() + session1.query.side_effect = lambda model: dataset_query + session2.query.side_effect = lambda model: document_query + session3.query.side_effect = lambda model: dataset_query + + monkeypatch.setattr( + "tasks.document_indexing_task.session_factory.create_session", + MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]), + ) + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False), + vector_space=SimpleNamespace(limit=0, size=0), + ) + monkeypatch.setattr( + "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features) + ) + monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock())) + + # Act + _document_indexing("dataset-1", ["doc-1"]) + + # Assert + session3.query.assert_called() + + def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test summary generation skipped when indexing_technique is not high_quality.""" + # Arrange + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + indexing_technique="economy", + summary_index_setting={"enable": True}, + ) + dataset_query = MagicMock() + dataset_query.where.return_value = dataset_query + 
dataset_query.first.return_value = dataset + + document_query = MagicMock() + document_query.where.return_value = document_query + document_query.all.return_value = [SimpleNamespace(id="doc-1")] + + session1 = MagicMock() + session2 = MagicMock() + session2.begin.return_value = nullcontext() + session3 = MagicMock() + session1.query.side_effect = lambda model: dataset_query + session2.query.side_effect = lambda model: document_query + session3.query.side_effect = lambda model: dataset_query + + monkeypatch.setattr( + "tasks.document_indexing_task.session_factory.create_session", + MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]), + ) + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False), + vector_space=SimpleNamespace(limit=0, size=0), + ) + monkeypatch.setattr( + "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features) + ) + monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock())) + + delay_mock = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock) + + # Act + _document_indexing("dataset-1", ["doc-1"]) + + # Assert + delay_mock.assert_not_called() + + def test_should_skip_summary_generation_when_indexing_paused(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test summary generation is skipped when indexing is paused.""" + # Arrange + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + dataset_query = MagicMock() + dataset_query.where.return_value = dataset_query + dataset_query.first.return_value = dataset + + document_query = MagicMock() + document_query.where.return_value = document_query + document_query.all.return_value = [SimpleNamespace(id="doc-1")] + + session1 = MagicMock() + session2 = MagicMock() + session2.begin.return_value = nullcontext() + session1.query.side_effect = lambda model: dataset_query + 
session2.query.side_effect = lambda model: document_query + + create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)]) + monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock) + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False), + vector_space=SimpleNamespace(limit=0, size=0), + ) + monkeypatch.setattr( + "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features) + ) + + runner = MagicMock() + runner.run.side_effect = DocumentIsPausedError("paused") + monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner)) + delay_mock = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock) + + # Act + _document_indexing("dataset-1", ["doc-1"]) + + # Assert + delay_mock.assert_not_called() + + def test_should_handle_indexing_runner_exception(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test generic indexing runner exception is handled.""" + # Arrange + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + dataset_query = MagicMock() + dataset_query.where.return_value = dataset_query + dataset_query.first.return_value = dataset + + document_query = MagicMock() + document_query.where.return_value = document_query + document_query.all.return_value = [SimpleNamespace(id="doc-1")] + + session1 = MagicMock() + session2 = MagicMock() + session2.begin.return_value = nullcontext() + session1.query.side_effect = lambda model: dataset_query + session2.query.side_effect = lambda model: document_query + + monkeypatch.setattr( + "tasks.document_indexing_task.session_factory.create_session", + MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)]), + ) + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False), + vector_space=SimpleNamespace(limit=0, size=0), + ) + monkeypatch.setattr( + 
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features) + ) + + runner = MagicMock() + runner.run.side_effect = RuntimeError("boom") + monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner)) + + delay_mock = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock) + + # Act + _document_indexing("dataset-1", ["doc-1"]) + + # Assert + delay_mock.assert_not_called() + + def test_should_log_missing_document_entry_in_summary_list(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test falsey document entries are handled in summary iteration.""" + + # Arrange + class _FalseyDocument: + def __init__(self, doc_id: str) -> None: + self.id = doc_id + + def __bool__(self) -> bool: + return False + + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + dataset_query = MagicMock() + dataset_query.where.return_value = dataset_query + dataset_query.first.return_value = dataset + + phase1_query = MagicMock() + phase1_query.where.return_value = phase1_query + phase1_query.all.return_value = [SimpleNamespace(id="doc-1")] + + summary_query = MagicMock() + summary_query.where.return_value = summary_query + summary_query.all.return_value = [_FalseyDocument("missing-doc")] + + session1 = MagicMock() + session2 = MagicMock() + session2.begin.return_value = nullcontext() + session3 = MagicMock() + session1.query.side_effect = lambda model: dataset_query + session2.query.side_effect = lambda model: phase1_query + session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query + + monkeypatch.setattr( + "tasks.document_indexing_task.session_factory.create_session", + MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]), + ) + + features = SimpleNamespace( + 
billing=SimpleNamespace(enabled=False), + vector_space=SimpleNamespace(limit=0, size=0), + ) + monkeypatch.setattr( + "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features) + ) + monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock())) + + delay_mock = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock) + + # Act + _document_indexing("dataset-1", ["doc-1"]) + + # Assert + delay_mock.assert_not_called() + + def test_normal_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test normal indexing task delegates to tenant queue handler.""" + # Arrange + handler = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler) + + # Act + normal_document_indexing_task("tenant-1", "dataset-1", ["doc-1"]) + + # Assert + handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], normal_document_indexing_task) + + def test_priority_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test priority indexing task delegates to tenant queue handler.""" + # Arrange + handler = MagicMock() + monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler) + + # Act + priority_document_indexing_task("tenant-1", "dataset-1", ["doc-1"]) + + # Assert + handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], priority_document_indexing_task)