Merge branch 'main' into 3-18-no-global-loading

This commit is contained in:
Stephen Zhou
2026-04-02 16:48:28 +08:00
committed by GitHub
10 changed files with 4180 additions and 34 deletions

View File

@@ -129,6 +129,7 @@ class VariableTruncator(BaseTruncator):
used_size += self.calculate_json_size(key)
if used_size > budget:
truncated_mapping[key] = "..."
is_truncated = True
continue
value_budget = (budget - used_size) // (length - len(truncated_mapping))
if isinstance(value, Segment):
@@ -164,9 +165,9 @@ class VariableTruncator(BaseTruncator):
result = self._truncate_segment(segment, self._max_size_bytes)
if result.value_size > self._max_size_bytes:
if isinstance(result.value, str):
result = self._truncate_string(result.value, self._max_size_bytes)
return TruncationResult(StringSegment(value=result.value), True)
if isinstance(result.value, StringSegment):
fallback_result = self._truncate_string(result.value.value, self._max_size_bytes)
return TruncationResult(StringSegment(value=fallback_result.value), True)
# Apply final fallback - convert to JSON string and truncate
json_str = dumps_with_segments(result.value, ensure_ascii=False)

View File

@@ -85,3 +85,644 @@ def test_get_provider_list_strips_credentials(service_with_fake_configurations:
assert len(custom_models) == 1
# The sanitizer should drop credentials in list response
assert custom_models[0].credentials is None
# === Merged from test_model_provider_service.py ===
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
from core.entities.model_entities import ModelStatus
from models.provider import ProviderType
from services import model_provider_service as service_module
from services.errors.app_model_config import ProviderNotFoundError
from services.model_provider_service import ModelProviderService
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
    """Build a ModelProviderService whose provider manager is replaced by a MagicMock.

    Returns both the service and the manager mock so tests can program the
    manager's return values and assert on its calls.
    """
    manager = MagicMock()
    service = ModelProviderService()
    # Patch the factory method so the service never touches a real manager.
    service._get_provider_manager = MagicMock(return_value=manager)
    return service, manager
def _build_provider_configuration(
    *,
    provider_name: str = "openai",
    supported_model_types: list[ModelType] | None = None,
    custom_models: list[Any] | None = None,
    custom_config_available: bool = True,
) -> SimpleNamespace:
    """Create a SimpleNamespace mimicking a provider configuration object.

    Only the attributes these tests exercise are populated; cosmetic fields
    (description, icons, background, help) are left as None.
    """
    if supported_model_types is None:
        supported_model_types = [ModelType.LLM]
    return SimpleNamespace(
        provider=SimpleNamespace(
            provider=provider_name,
            label=I18nObject(en_US=provider_name),
            description=None,
            icon_small=None,
            icon_small_dark=None,
            background=None,
            help=None,
            supported_model_types=supported_model_types,
            configurate_methods=[],
            provider_credential_schema=None,
            model_credential_schema=None,
        ),
        preferred_provider_type=ProviderType.CUSTOM,
        custom_configuration=SimpleNamespace(
            provider=SimpleNamespace(
                current_credential_id="cred-1",
                current_credential_name="Credential 1",
                available_credentials=[],
            ),
            models=custom_models,
            can_added_models=[],
        ),
        system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
        # Callable so the production code can invoke it like the real method.
        is_custom_configuration_available=lambda: custom_config_available,
    )
def test__get_provider_configuration_should_return_configuration_when_provider_exists() -> None:
    """_get_provider_configuration returns the manager's entry for a known provider."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    provider_configuration = SimpleNamespace(name="provider-config")
    manager.get_configurations.return_value = {"openai": provider_configuration}
    # Act
    result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
    # Assert: identity, not equality — the exact object must be passed through.
    assert result is provider_configuration
def test__get_provider_configuration_should_raise_error_when_provider_is_missing() -> None:
    """A provider absent from the manager's configurations raises ProviderNotFoundError."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_configurations.return_value = {}
    # Act / Assert
    with pytest.raises(ProviderNotFoundError, match="does not exist"):
        service._get_provider_configuration(tenant_id="tenant-1", provider="missing")
def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status() -> None:
    """Providers not supporting the requested model type are filtered out, and a
    provider without available custom configuration reports a 'no-configure' status."""
    # Arrange
    allowed = _build_provider_configuration(
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
        custom_config_available=False,
    )
    filtered = _build_provider_configuration(
        provider_name="embedding",
        supported_model_types=[ModelType.TEXT_EMBEDDING],
        custom_config_available=True,
    )
    service, manager = _create_service_with_mocked_manager()
    manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
    # Act
    result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
    # Assert: only the LLM-capable provider survives the filter.
    assert len(result) == 1
    assert result[0].provider == "openai"
    assert result[0].custom_configuration.status.value == "no-configure"
def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context() -> None:
    """Models returned by the configurations are wrapped into response entities."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()

    class _Model:
        # Minimal stand-in exposing only the model_dump() payload the service reads.
        def __init__(self, model_name: str) -> None:
            self.model_name = model_name

        def model_dump(self) -> dict[str, Any]:
            return {
                "model": self.model_name,
                "label": {"en_US": self.model_name},
                "model_type": ModelType.LLM,
                "features": [],
                "fetch_from": FetchFrom.PREDEFINED_MODEL,
                "model_properties": {},
                "deprecated": False,
                "status": ModelStatus.ACTIVE,
                "load_balancing_enabled": False,
                "has_invalid_load_balancing_configs": False,
                "provider": {
                    "provider": "openai",
                    "label": {"en_US": "OpenAI"},
                    "icon_small": None,
                    "icon_small_dark": None,
                    "supported_model_types": [ModelType.LLM],
                },
            }

    provider_configurations = SimpleNamespace(
        get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
    )
    manager.get_configurations.return_value = provider_configurations
    # Act
    result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
    # Assert
    assert len(result) == 2
    assert result[0].model == "gpt-4o"
    assert result[1].provider.provider == "openai"
    provider_configurations.get_models.assert_called_once_with(provider="openai")
@pytest.mark.parametrize(
    ("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
    [
        (
            "get_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
            "get_provider_credential",
            {"credential_id": "cred-1"},
            {"token": "abc"},
        ),
        (
            "validate_provider_credentials",
            {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
            "validate_provider_credentials",
            # A tuple means the delegate is expected to be called positionally.
            ({"token": "abc"},),
            None,
        ),
        (
            "create_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}, "credential_name": "A"},
            "create_provider_credential",
            ({"token": "abc"}, "A"),
            None,
        ),
        (
            "update_provider_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "credentials": {"token": "abc"},
                "credential_id": "cred-1",
                "credential_name": "B",
            },
            "update_provider_credential",
            {"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
            None,
        ),
        (
            "remove_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
            "delete_provider_credential",
            {"credential_id": "cred-1"},
            None,
        ),
        (
            "switch_active_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
            "switch_active_provider_credential",
            {"credential_id": "cred-1"},
            None,
        ),
    ],
)
def test_provider_credential_methods_should_delegate_to_provider_configuration(
    method_name: str,
    method_kwargs: dict[str, Any],
    provider_method_name: str,
    provider_call_kwargs: Any,
    provider_return: Any,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Every provider-credential service method resolves the provider configuration
    once and forwards the call with the parametrized positional/keyword arguments."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    getattr(provider_configuration, provider_method_name).return_value = provider_return
    get_provider_config_mock = MagicMock(return_value=provider_configuration)
    monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
    # Act
    result = getattr(service, method_name)(**method_kwargs)
    # Assert
    get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
    provider_method = getattr(provider_configuration, provider_method_name)
    if isinstance(provider_call_kwargs, tuple):
        provider_method.assert_called_once_with(*provider_call_kwargs)
    elif isinstance(provider_call_kwargs, dict):
        provider_method.assert_called_once_with(**provider_call_kwargs)
    else:
        # Fallback for a single positional argument; unused by the current cases.
        provider_method.assert_called_once_with(provider_call_kwargs)
    # Only the getter has a meaningful return value to verify.
    if method_name == "get_provider_credential":
        assert result == {"token": "abc"}
@pytest.mark.parametrize(
    ("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
    [
        (
            "get_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "get_custom_model_credential",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            {"api_key": "x"},
        ),
        (
            "validate_model_credentials",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
            },
            "validate_custom_model_credentials",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
            None,
        ),
        (
            "create_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_name": "cred-a",
            },
            "create_custom_model_credential",
            {
                "model_type": ModelType.LLM,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_name": "cred-a",
            },
            None,
        ),
        (
            "update_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_id": "cred-1",
                "credential_name": "cred-b",
            },
            "update_custom_model_credential",
            {
                "model_type": ModelType.LLM,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_id": "cred-1",
                "credential_name": "cred-b",
            },
            None,
        ),
        (
            "remove_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "delete_custom_model_credential",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            None,
        ),
        (
            "switch_active_custom_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "switch_custom_model_credential",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            None,
        ),
        (
            "add_model_credential_to_model_list",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "add_model_credential_to_model",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            None,
        ),
        (
            "remove_model",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
            },
            "delete_custom_model",
            {"model_type": ModelType.LLM, "model": "gpt-4o"},
            None,
        ),
    ],
)
def test_custom_model_methods_should_convert_model_type_and_delegate(
    method_name: str,
    method_kwargs: dict[str, Any],
    provider_method_name: str,
    expected_kwargs: dict[str, Any],
    provider_return: Any,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Custom-model service methods convert the string model_type to the ModelType
    enum and delegate to the corresponding provider-configuration method."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    getattr(provider_configuration, provider_method_name).return_value = provider_return
    get_provider_config_mock = MagicMock(return_value=provider_configuration)
    monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
    # Act
    result = getattr(service, method_name)(**method_kwargs)
    # Assert
    get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
    getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
    # Only the getter has a meaningful return value to verify.
    if method_name == "get_model_credential":
        assert result == {"api_key": "x"}
def test_get_models_by_model_type_should_group_active_non_deprecated_models() -> None:
    """Deprecated models are dropped and providers left with no models are omitted
    from the grouped response."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    openai_provider = SimpleNamespace(
        provider="openai",
        label=I18nObject(en_US="OpenAI"),
        icon_small=None,
        icon_small_dark=None,
    )
    anthropic_provider = SimpleNamespace(
        provider="anthropic",
        label=I18nObject(en_US="Anthropic"),
        icon_small=None,
        icon_small_dark=None,
    )
    models = [
        # Active, non-deprecated — must survive.
        SimpleNamespace(
            provider=openai_provider,
            model="gpt-4o",
            label=I18nObject(en_US="GPT-4o"),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=False,
        ),
        # Deprecated — must be filtered out.
        SimpleNamespace(
            provider=openai_provider,
            model="old-openai",
            label=I18nObject(en_US="Old OpenAI"),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=True,
        ),
        # Deprecated and the provider's only model — provider must be omitted.
        SimpleNamespace(
            provider=anthropic_provider,
            model="old-anthropic",
            label=I18nObject(en_US="Old Anthropic"),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=True,
        ),
    ]
    provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
    manager.get_configurations.return_value = provider_configurations
    # Act
    result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
    # Assert
    provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
    assert len(result) == 1
    assert result[0].provider == "openai"
    assert len(result[0].models) == 1
    assert result[0].models[0].model == "gpt-4o"
@pytest.mark.parametrize(
    ("credentials", "schema", "expected_count"),
    [
        # No credentials: schema lookup is skipped entirely.
        (None, None, 0),
        # Credentials but no schema: no rules available.
        ({"api_key": "x"}, None, 0),
        # Credentials and a schema with one rule.
        (
            {"api_key": "x"},
            SimpleNamespace(
                parameter_rules=[
                    ParameterRule(
                        name="temperature",
                        label=I18nObject(en_US="Temperature"),
                        type=ParameterType.FLOAT,
                    )
                ]
            ),
            1,
        ),
    ],
)
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
    credentials: dict[str, Any] | None,
    schema: Any,
    expected_count: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """get_model_parameter_rules returns an empty list when credentials or schema
    are missing, and only fetches the schema when credentials exist."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    provider_configuration.get_current_credentials.return_value = credentials
    provider_configuration.get_model_schema.return_value = schema
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
    # Act
    result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")
    # Assert
    assert len(result) == expected_count
    provider_configuration.get_current_credentials.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4o")
    if credentials:
        provider_configuration.get_model_schema.assert_called_once_with(
            model_type=ModelType.LLM,
            model="gpt-4o",
            credentials=credentials,
        )
    else:
        provider_configuration.get_model_schema.assert_not_called()
def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model() -> None:
    """The manager's default-model entity is wrapped into a response object."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.return_value = SimpleNamespace(
        model="gpt-4o",
        model_type=ModelType.LLM,
        provider=SimpleNamespace(
            provider="openai",
            label=I18nObject(en_US="OpenAI"),
            icon_small=None,
            supported_model_types=[ModelType.LLM],
        ),
    )
    # Act
    result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
    # Assert
    assert result is not None
    assert result.model == "gpt-4o"
    assert result.provider.provider == "openai"
    manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)
def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none() -> None:
    """When the manager reports no default model, the service yields None."""
    # Arrange: the manager has no default model configured.
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.return_value = None

    # Act
    outcome = service.get_default_model_of_model_type(
        tenant_id="tenant-1",
        model_type=ModelType.LLM.value,
    )

    # Assert
    assert outcome is None
def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception() -> None:
    """A manager failure is swallowed and surfaced to the caller as None."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.side_effect = RuntimeError("boom")
    # Act
    result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
    # Assert
    assert result is None
def test_update_default_model_of_model_type_should_delegate_to_provider_manager() -> None:
    """The update call is forwarded verbatim, with model_type converted to the enum."""
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    # Act
    service.update_default_model_of_model_type(
        tenant_id="tenant-1",
        model_type=ModelType.LLM.value,
        provider="openai",
        model="gpt-4o",
    )
    # Assert
    manager.update_default_model_record.assert_called_once_with(
        tenant_id="tenant-1",
        model_type=ModelType.LLM,
        provider="openai",
        model="gpt-4o",
    )
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(monkeypatch: pytest.MonkeyPatch) -> None:
    """The icon bytes and MIME type come straight from the plugin factory."""
    # Arrange
    service = ModelProviderService()
    factory_instance = MagicMock()
    factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
    factory_constructor = MagicMock(return_value=factory_instance)
    # Patch the module-level factory constructor so no plugin system is touched.
    monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
    # Act
    result = service.get_model_provider_icon(
        tenant_id="tenant-1",
        provider="openai",
        icon_type="icon_small",
        lang="en_US",
    )
    # Assert
    factory_constructor.assert_called_once_with(tenant_id="tenant-1")
    factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
    assert result == (b"icon-bytes", "image/png")
def test_switch_preferred_provider_should_convert_enum_and_delegate(monkeypatch: pytest.MonkeyPatch) -> None:
    """The string preferred_provider_type is converted to ProviderType before delegation."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
    # Act
    service.switch_preferred_provider(
        tenant_id="tenant-1",
        provider="openai",
        preferred_provider_type=ProviderType.SYSTEM.value,
    )
    # Assert
    provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
@pytest.mark.parametrize(
    ("method_name", "provider_method_name"),
    [
        ("enable_model", "enable_model"),
        ("disable_model", "disable_model"),
    ],
)
def test_model_enablement_methods_should_convert_model_type_and_delegate(
    method_name: str,
    provider_method_name: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """enable_model/disable_model convert model_type to the enum and delegate."""
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
    # Act
    getattr(service, method_name)(
        tenant_id="tenant-1",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM.value,
    )
    # Assert
    getattr(provider_configuration, provider_method_name).assert_called_once_with(
        model="gpt-4o",
        model_type=ModelType.LLM,
    )

View File

@@ -316,7 +316,7 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result == expected_detail
@@ -346,7 +346,7 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result["name"] == f"App from {mode}"
@@ -369,7 +369,7 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result is None
@@ -392,7 +392,7 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result == {}
@@ -432,9 +432,197 @@ class TestRecommendedAppServiceGetDetail:
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result["model_config"] == complex_model_config
assert len(result["workflows"]) == 2
assert len(result["tools"]) == 3
# === Merged from test_recommended_app_service_additional.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from services import recommended_app_service as service_module
from services.recommended_app_service import RecommendedAppService
def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any]:
    """Narrow an optional recommendation payload to a dict for the type checker.

    Purely a typing aid: the value is returned unchanged, even when it is None.
    """
    return cast(dict[str, Any], result)
@pytest.fixture
def mocked_db_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
    """Replace the service module's ``db`` with a namespace holding a mocked session."""
    session = MagicMock()
    monkeypatch.setattr(service_module, "db", SimpleNamespace(session=session))
    # Hand the session mock to the test so it can program queries and assert calls.
    return session
def _mock_factory_for_apps(
    monkeypatch: pytest.MonkeyPatch,
    *,
    mode: str,
    result: dict[str, Any],
    fallback_result: dict[str, Any] | None = None,
) -> tuple[MagicMock, MagicMock]:
    """Patch the retrieval factory so app-list fetches return canned payloads.

    ``result`` is served by the mode-selected retrieval; ``fallback_result``,
    when given, is served by the builtin retrieval used as a fallback.
    Returns the (retrieval, builtin) mock instances for assertions.
    """
    retrieval_instance = MagicMock()
    retrieval_instance.get_recommended_apps_and_categories.return_value = result
    retrieval_factory = MagicMock(return_value=retrieval_instance)
    # raising=False: the config attribute may not exist on the stub config object.
    monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False)
    monkeypatch.setattr(
        service_module.RecommendAppRetrievalFactory,
        "get_recommend_app_factory",
        MagicMock(return_value=retrieval_factory),
    )
    builtin_instance = MagicMock()
    if fallback_result is not None:
        builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result
    monkeypatch.setattr(
        service_module.RecommendAppRetrievalFactory,
        "get_buildin_recommend_app_retrieval",
        MagicMock(return_value=builtin_instance),
    )
    return retrieval_instance, builtin_instance
def test_get_recommended_apps_and_categories_should_not_query_trial_table_when_trial_feature_disabled(
    monkeypatch: pytest.MonkeyPatch,
    mocked_db_session: MagicMock,
) -> None:
    """With trial apps disabled, no DB query is issued and the remote result passes through."""
    # Arrange
    expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]}
    retrieval_instance, builtin_instance = _mock_factory_for_apps(
        monkeypatch,
        mode="remote",
        result=expected,
    )
    monkeypatch.setattr(
        service_module.FeatureService,
        "get_system_features",
        MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
    )
    # Act
    result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
    # Assert
    assert result == expected
    retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
    builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called()
    mocked_db_session.scalar.assert_not_called()
def test_get_recommended_apps_and_categories_should_fallback_and_enrich_can_trial_when_trial_feature_enabled(
    monkeypatch: pytest.MonkeyPatch,
    mocked_db_session: MagicMock,
) -> None:
    """An empty remote result falls back to the builtin list (forced to en-US) and
    each app's ``can_trial`` flag reflects the per-app trial lookup."""
    # Arrange
    remote_result = {"recommended_apps": [], "categories": []}
    fallback_result = {"recommended_apps": [{"app_id": "app-1"}, {"app_id": "app-2"}], "categories": ["all"]}
    _, builtin_instance = _mock_factory_for_apps(
        monkeypatch,
        mode="remote",
        result=remote_result,
        fallback_result=fallback_result,
    )
    monkeypatch.setattr(
        service_module.FeatureService,
        "get_system_features",
        MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
    )
    # First lookup finds a trial record, second does not.
    mocked_db_session.scalar.side_effect = [SimpleNamespace(id="trial-app"), None]
    # Act
    result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")
    # Assert
    builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
    assert result["recommended_apps"][0]["can_trial"] is True
    assert result["recommended_apps"][1]["can_trial"] is False
    assert mocked_db_session.scalar.call_count == 2
@pytest.mark.parametrize(
    ("trial_query_result", "expected_can_trial"),
    [
        (SimpleNamespace(id="trial"), True),
        (None, False),
    ],
)
def test_get_recommend_app_detail_should_set_can_trial_when_trial_feature_enabled(
    monkeypatch: pytest.MonkeyPatch,
    mocked_db_session: MagicMock,
    trial_query_result: Any,
    expected_can_trial: bool,
) -> None:
    """The detail payload gains a ``can_trial`` flag mirroring the trial lookup result."""
    # Arrange
    detail = {"id": "app-1", "name": "Test App"}
    retrieval_instance = MagicMock()
    retrieval_instance.get_recommend_app_detail.return_value = detail
    retrieval_factory = MagicMock(return_value=retrieval_instance)
    monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False)
    monkeypatch.setattr(
        service_module.RecommendAppRetrievalFactory,
        "get_recommend_app_factory",
        MagicMock(return_value=retrieval_factory),
    )
    monkeypatch.setattr(
        service_module.FeatureService,
        "get_system_features",
        MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
    )
    mocked_db_session.scalar.return_value = trial_query_result
    # Act
    result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail("app-1"))
    # Assert
    assert result["id"] == "app-1"
    assert result["can_trial"] is expected_can_trial
    mocked_db_session.scalar.assert_called_once()
def test_add_trial_app_record_should_increment_count_when_existing_record_found(
    mocked_db_session: MagicMock,
) -> None:
    """An existing trial record has its count bumped in place — no new row is added."""
    # Arrange
    existing_record = SimpleNamespace(count=3)
    mocked_db_session.scalar.return_value = existing_record
    # Act
    RecommendedAppService.add_trial_app_record("app-1", "account-1")
    # Assert
    assert existing_record.count == 4
    mocked_db_session.scalar.assert_called_once()
    mocked_db_session.commit.assert_called_once()
    mocked_db_session.add.assert_not_called()
def test_add_trial_app_record_should_create_new_record_when_no_existing_record(
    mocked_db_session: MagicMock,
) -> None:
    """With no prior record, a fresh row with count=1 is added and committed."""
    # Arrange
    mocked_db_session.scalar.return_value = None
    # Act
    RecommendedAppService.add_trial_app_record("app-2", "account-2")
    # Assert
    mocked_db_session.scalar.assert_called_once()
    mocked_db_session.add.assert_called_once()
    # Inspect the ORM object handed to session.add.
    added = mocked_db_session.add.call_args.args[0]
    assert added.app_id == "app-2"
    assert added.account_id == "account-2"
    assert added.count == 1
    mocked_db_session.commit.assert_called_once()

View File

@@ -1,12 +1,15 @@
import unittest
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
from events.event_handlers.sync_workflow_schedule_when_app_published import (
sync_schedule_from_workflow,
)
@@ -14,6 +17,8 @@ from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
from models.account import Account, TenantAccountJoin
from models.trigger import WorkflowSchedulePlan
from models.workflow import Workflow
from services.errors.account import AccountNotFoundError
from services.trigger import schedule_service as service_module
from services.trigger.schedule_service import ScheduleService
@@ -775,5 +780,158 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
mock_session.commit.assert_called_once()
@pytest.fixture
def session_mock() -> MagicMock:
    """Provide a MagicMock constrained to the SQLAlchemy Session interface."""
    return MagicMock(spec=Session)
def _workflow(**kwargs: Any) -> Workflow:
    """Fabricate a Workflow-typed namespace carrying only the given attributes."""
    return cast(Workflow, SimpleNamespace(**kwargs))
def test_update_schedule_should_update_only_node_id_without_recomputing_time(
    session_mock: MagicMock,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Updating only node_id must not recompute next_run_at."""
    # Arrange
    schedule = MagicMock(spec=WorkflowSchedulePlan)
    schedule.cron_expression = "0 10 * * *"
    schedule.timezone = "UTC"
    session_mock.get.return_value = schedule
    next_run_mock = MagicMock(return_value=datetime(2026, 1, 1, 10, 0, tzinfo=UTC))
    monkeypatch.setattr(service_module, "calculate_next_run_at", next_run_mock)
    # Act
    result = ScheduleService.update_schedule(
        session=session_mock,
        schedule_id="schedule-1",
        updates=SchedulePlanUpdate(node_id="node-new"),
    )
    # Assert
    assert result is schedule
    assert schedule.node_id == "node-new"
    # Time recomputation must be skipped when only the node id changes.
    next_run_mock.assert_not_called()
    session_mock.flush.assert_called_once()
def test_get_tenant_owner_should_raise_when_account_record_missing(session_mock: MagicMock) -> None:
    """A join row pointing at a deleted account raises AccountNotFoundError."""
    # Arrange: the tenant join exists but the account lookup returns nothing.
    join = SimpleNamespace(account_id="account-404")
    session_mock.execute.return_value.scalar_one_or_none.return_value = join
    session_mock.get.return_value = None
    # Act / Assert
    with pytest.raises(AccountNotFoundError, match="Account not found: account-404"):
        ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")
def test_get_tenant_owner_should_raise_when_no_owner_or_admin_found(session_mock: MagicMock) -> None:
    """Neither an owner nor an admin join row → AccountNotFoundError for the tenant."""
    # Arrange: both lookups (owner, then admin) come back empty.
    session_mock.execute.return_value.scalar_one_or_none.side_effect = [None, None]
    # Act / Assert
    with pytest.raises(AccountNotFoundError, match="Account not found for tenant: tenant-1"):
        ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")
def test_update_next_run_at_should_raise_when_schedule_not_found(session_mock: MagicMock) -> None:
    """A missing schedule id surfaces as ScheduleNotFoundError."""
    # Arrange: the session lookup yields no schedule row.
    session_mock.get.return_value = None

    # Act / Assert
    expected_message = "Schedule not found: schedule-1"
    with pytest.raises(ScheduleNotFoundError, match=expected_message):
        ScheduleService.update_next_run_at(session=session_mock, schedule_id="schedule-1")
def test_to_schedule_config_should_build_from_cron_mode() -> None:
    """Cron-mode node data maps directly onto the ScheduleConfig fields."""
    # Arrange
    node_config: dict[str, Any] = {
        "id": "node-1",
        "data": {
            "mode": "cron",
            "cron_expression": "0 12 * * *",
            "timezone": "Asia/Kolkata",
        },
    }
    # Act
    result = ScheduleService.to_schedule_config(node_config=node_config)
    # Assert
    assert result.node_id == "node-1"
    assert result.cron_expression == "0 12 * * *"
    assert result.timezone == "Asia/Kolkata"
def test_to_schedule_config_should_raise_for_cron_mode_without_expression() -> None:
    """Cron mode with an empty expression is rejected."""
    # Arrange: cron mode but a blank cron_expression.
    data = {"mode": "cron", "cron_expression": ""}
    node_config = {"id": "node-1", "data": data}

    # Act / Assert
    with pytest.raises(ScheduleConfigError, match="Cron expression is required for cron mode"):
        ScheduleService.to_schedule_config(node_config=node_config)
def test_to_schedule_config_should_build_from_visual_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    """Visual mode delegates to visual_to_cron and adopts the produced expression."""
    # Arrange
    node_config = {
        "id": "node-1",
        "data": {
            "mode": "visual",
            "frequency": "daily",
            "visual_config": {"time": "9:30 AM"},
            "timezone": "UTC",
        },
    }
    # Stub the visual→cron conversion; its own logic is tested elsewhere.
    monkeypatch.setattr(ScheduleService, "visual_to_cron", MagicMock(return_value="30 9 * * *"))
    # Act
    result = ScheduleService.to_schedule_config(node_config=node_config)
    # Assert
    assert result.cron_expression == "30 9 * * *"
def test_to_schedule_config_should_raise_for_invalid_mode() -> None:
    """An unrecognized schedule mode raises ScheduleConfigError."""
    # Arrange
    node_config = {"id": "node-1", "data": {"mode": "manual"}}
    # Act / Assert
    with pytest.raises(ScheduleConfigError, match="Invalid schedule mode: manual"):
        ScheduleService.to_schedule_config(node_config=node_config)
def test_extract_schedule_config_should_raise_when_graph_is_empty() -> None:
    """A workflow with an empty graph cannot yield a schedule config."""
    # Arrange
    workflow = _workflow(graph_dict={})
    # Act / Assert
    with pytest.raises(ScheduleConfigError, match="Workflow graph is empty"):
        ScheduleService.extract_schedule_config(workflow=workflow)
def test_extract_schedule_config_should_raise_when_mode_invalid() -> None:
    """A schedule-trigger node with an invalid mode raises ScheduleConfigError."""
    # Arrange
    workflow = _workflow(
        graph_dict={
            "nodes": [
                {
                    "id": "schedule-1",
                    "data": {
                        "type": TRIGGER_SCHEDULE_NODE_TYPE,
                        "mode": "invalid",
                    },
                }
            ]
        }
    )
    # Act / Assert
    with pytest.raises(ScheduleConfigError, match="Invalid schedule mode: invalid"):
        ScheduleService.extract_schedule_config(workflow=workflow)
# Allow running this module directly with the unittest runner.
if __name__ == "__main__":
    unittest.main()

View File

@@ -12,6 +12,7 @@ This test suite covers all functionality of the current VariableTruncator includ
import functools
import json
import uuid
from collections.abc import Mapping
from typing import Any
from uuid import uuid4
@@ -199,14 +200,14 @@ class TestArrayTruncation:
def test_small_array_no_truncation(self, small_truncator: VariableTruncator):
"""Test that small arrays are not truncated."""
small_array = [1, 2]
small_array: list[object] = [1, 2]
result = small_truncator._truncate_array(small_array, 1000)
assert result.value == small_array
assert result.truncated is False
def test_array_element_limit_truncation(self, small_truncator: VariableTruncator):
"""Test that arrays over element limit are truncated."""
large_array = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
large_array: list[object] = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
result = small_truncator._truncate_array(large_array, 1000)
assert result.truncated is True
@@ -215,7 +216,7 @@ class TestArrayTruncation:
def test_array_size_budget_truncation(self, small_truncator: VariableTruncator):
"""Test array truncation due to size budget constraints."""
# Create array with strings that will exceed size budget
large_strings = ["very long string " * 5, "another long string " * 5]
large_strings: list[object] = ["very long string " * 5, "another long string " * 5]
result = small_truncator._truncate_array(large_strings, 50)
assert result.truncated is True
@@ -276,10 +277,10 @@ class TestObjectTruncation:
# Values should be truncated if they exist
for key, value in result.value.items():
if isinstance(value, str):
original_value = obj_with_long_values[key]
# Value should be same or smaller
assert len(value) <= len(original_value)
assert isinstance(value, str)
original_value = obj_with_long_values[key]
# Value should be same or smaller
assert len(value) <= len(original_value)
def test_object_key_dropping(self, small_truncator):
"""Test object truncation where keys are dropped due to size constraints."""
@@ -506,10 +507,9 @@ class TestEdgeCases:
truncator = VariableTruncator(string_length_limit=10)
# Unicode characters
unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀" # Each emoji counts as 1 character
unicode_text = "你好世界你好世界你好世界" # Multi-byte UTF-8 characters
result = truncator.truncate(StringSegment(value=unicode_text))
if len(unicode_text) > 10:
assert result.truncated is True
assert result.truncated is True
# Special JSON characters
special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
@@ -631,13 +631,12 @@ class TestIntegrationScenarios:
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
# Should handle all data types appropriately
if result.truncated:
# Verify the result is smaller or equal than original
original_size = truncator.calculate_json_size(mixed_data)
if isinstance(result.result, ObjectSegment):
result_size = truncator.calculate_json_size(result.result.value)
assert result_size <= original_size
assert result.truncated is True
assert isinstance(result.result, ObjectSegment)
# Verify the result is smaller or equal than original
original_size = truncator.calculate_json_size(mixed_data)
result_size = truncator.calculate_json_size(result.result.value)
assert result_size <= original_size
def test_file_and_array_file_variable_mapping(self, file):
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
@@ -675,3 +674,229 @@ def test_dummy_variable_truncator_methods():
assert isinstance(result, TruncationResult)
assert result.result == segment
assert result.truncated is False
# === Merged from test_variable_truncator_additional.py ===
from typing import Any
import pytest
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
from graphon.variables.types import SegmentType
from services import variable_truncator as truncator_module
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
class _AbstractPassthrough(BaseTruncator):
    """Concrete shim that forwards straight to BaseTruncator's abstract bodies.

    Exists solely so the placeholder implementations on the abstract base can
    be executed (they appear to return None; the tests below rely on that).
    NOTE(review): `Mapping` must be imported earlier in this file — confirm.
    """

    def truncate(self, segment: Any) -> TruncationResult:
        # Directly invoke the abstract placeholder body.
        return super().truncate(segment)  # type: ignore[misc]

    def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
        # Directly invoke the abstract placeholder body.
        return super().truncate_variable_mapping(v)  # type: ignore[misc]
def test_base_truncator_methods_should_execute_abstract_placeholders() -> None:
    """The abstract method bodies on BaseTruncator execute and yield None."""
    shim = _AbstractPassthrough()

    assert shim.truncate(StringSegment(value="x")) is None
    assert shim.truncate_variable_mapping({"a": 1}) is None
def test_default_should_use_dify_config_limits(monkeypatch: pytest.MonkeyPatch) -> None:
    """VariableTruncator.default() reads its three limits from dify_config."""
    configured_limits = {
        "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE": 111,
        "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH": 7,
        "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH": 33,
    }
    for attribute, limit in configured_limits.items():
        monkeypatch.setattr(truncator_module.dify_config, attribute, limit)

    truncator = VariableTruncator.default()

    assert truncator._max_size_bytes == 111
    assert truncator._array_element_limit == 7
    assert truncator._string_length_limit == 33
def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis() -> None:
    """A key whose size alone blows the budget gets its value replaced with '...'."""
    tiny_budget_truncator = VariableTruncator(max_size_bytes=5)

    truncated_mapping, was_truncated = tiny_budget_truncator.truncate_variable_mapping(
        {"very_long_key": "value"}
    )

    assert truncated_mapping == {"very_long_key": "..."}
    assert was_truncated is True
def test_truncate_variable_mapping_should_handle_segment_values() -> None:
    """Segment values that fit within the budget pass through unchanged."""
    truncator = VariableTruncator(max_size_bytes=100)

    mapping, was_truncated = truncator.truncate_variable_mapping({"seg": StringSegment(value="hello")})

    segment = mapping["seg"]
    assert isinstance(segment, StringSegment)
    assert segment.value == "hello"
    assert was_truncated is False
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (None, False),
        (True, False),
        (1, False),
        (1.5, False),
        ("x", True),
        ({"k": "v"}, True),
    ],
)
def test_json_value_needs_truncation_should_match_expected_rules(value: Any, expected: bool) -> None:
    """Strings and containers are truncation candidates; None/bool/numbers are not."""
    # Act
    result = VariableTruncator._json_value_needs_truncation(value)
    # Assert
    assert result is expected
def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A StringSegment still over budget after segment truncation is re-truncated as a raw string."""
    truncator = VariableTruncator(max_size_bytes=10)
    oversized_string_result = truncator_module._PartResult(
        value=StringSegment(value="this is too long"),
        value_size=100,
        truncated=True,
    )
    # Force the first-stage truncation to report an over-budget StringSegment.
    monkeypatch.setattr(truncator, "_truncate_segment", lambda *_a, **_kw: oversized_string_result)

    outcome = truncator.truncate(StringSegment(value="input"))

    assert outcome.truncated is True
    assert isinstance(outcome.result, StringSegment)
    # The string fallback must not route through the JSON dump, which would quote the value.
    assert not outcome.result.value.startswith('"')
def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A segment that claims to need truncation but has no handler trips an assertion."""
    truncator = VariableTruncator()
    # Pretend every segment needs truncation, including unsupported ones.
    monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)

    with pytest.raises(AssertionError):
        truncator._truncate_segment(IntegerSegment(value=1), 10)
def test_calculate_json_size_should_unwrap_segment_values() -> None:
    """Sizing a segment equals sizing the plain value it wraps."""
    wrapped_size = VariableTruncator.calculate_json_size(StringSegment(value="abc"))
    plain_size = VariableTruncator.calculate_json_size("abc")

    assert wrapped_size == plain_size
def test_calculate_json_size_should_handle_updated_variable_instances() -> None:
    """UpdatedVariable instances are serializable and report a positive size."""
    updated_variable = UpdatedVariable(
        name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v"
    )

    assert VariableTruncator.calculate_json_size(updated_variable) > 0
def test_maybe_qa_structure_should_validate_shape() -> None:
    """A QA structure requires a list under the 'qa_chunks' key."""
    shape_checks = [
        ({"qa_chunks": []}, True),
        ({"qa_chunks": "not-list"}, False),
        ({}, False),
    ]
    for candidate, expected in shape_checks:
        assert VariableTruncator._maybe_qa_structure(candidate) is expected
def test_maybe_parent_child_structure_should_validate_shape() -> None:
    """Parent-child structures need a string 'parent_mode' and a list 'parent_child_chunks'."""
    valid_shape = {"parent_mode": "full", "parent_child_chunks": []}
    non_string_mode = {"parent_mode": 1, "parent_child_chunks": []}
    non_list_chunks = {"parent_mode": "full", "parent_child_chunks": "bad"}

    assert VariableTruncator._maybe_parent_child_structure(valid_shape) is True
    assert VariableTruncator._maybe_parent_child_structure(non_string_mode) is False
    assert VariableTruncator._maybe_parent_child_structure(non_list_chunks) is False
def test_truncate_object_should_truncate_segment_values_inside_object() -> None:
    """Segment values nested in an object are truncated but keep their segment type."""
    truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)

    part = truncator._truncate_object({"s": StringSegment(value="long-content")}, 20)

    assert part.truncated is True
    assert isinstance(part.value["s"], StringSegment)
def test_truncate_json_primitives_should_handle_updated_variable_input() -> None:
    """UpdatedVariable inputs are converted to a plain dict representation."""
    truncator = VariableTruncator(max_size_bytes=100)
    updated_variable = UpdatedVariable(
        name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v"
    )

    part = truncator._truncate_json_primitives(updated_variable, 100)

    assert isinstance(part.value, dict)
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type() -> None:
    """Values outside the supported JSON primitives trip an assertion."""
    unsupported_value = object()

    with pytest.raises(AssertionError):
        VariableTruncator()._truncate_json_primitives(unsupported_value, 100)  # type: ignore[arg-type]
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Non-string segments that remain oversized are dumped to a JSON string fallback."""
    truncator = VariableTruncator(max_size_bytes=10)
    oversized_object_result = truncator_module._PartResult(
        value=ObjectSegment(value={"k": "v"}),
        value_size=100,
        truncated=True,
    )
    # Force the first-stage truncation to report an over-budget ObjectSegment.
    monkeypatch.setattr(truncator, "_truncate_segment", lambda *_a, **_kw: oversized_object_result)

    outcome = truncator.truncate(ObjectSegment(value={"a": "b"}))

    assert outcome.truncated is True
    assert isinstance(outcome.result, StringSegment)

View File

@@ -559,3 +559,757 @@ class TestWebhookServiceUnit:
result = _prepare_webhook_execution("test_webhook", is_debug=True)
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
# === Merged from test_webhook_service_additional.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from flask import Flask
from graphon.variables.types import SegmentType
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
    """Provide a throwaway Flask app for building test request contexts."""
    return Flask(__name__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
    """Route the service module's DB access through *session*, bypassing the real engine."""
    monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
    monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
    """Duck-typed WorkflowWebhookTrigger stand-in; cast keeps the type checker happy."""
    return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))


def _workflow(**kwargs: Any) -> Workflow:
    """Duck-typed Workflow stand-in built from keyword attributes."""
    return cast(Workflow, SimpleNamespace(**kwargs))


def _app(**kwargs: Any) -> App:
    """Duck-typed App stand-in built from keyword attributes."""
    return cast(App, SimpleNamespace(**kwargs))
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """Lookup fails fast when no webhook trigger row matches the id."""
    # Arrange
    fake_session = MagicMock()
    fake_session.query.return_value = _FakeQuery(None)
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="Webhook not found"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")


def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A webhook without a matching app-trigger record is rejected (non-debug path)."""
    # Arrange
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    fake_session = MagicMock()
    # side_effect entries feed consecutive session.query(...) calls in service order.
    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)]
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="App trigger not found"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")


def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A rate-limited app trigger blocks webhook resolution."""
    # Arrange
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
    fake_session = MagicMock()
    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="rate limited"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")


def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A disabled app trigger blocks webhook resolution."""
    # Arrange
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
    fake_session = MagicMock()
    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="disabled"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")


def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """An enabled trigger still fails when the underlying workflow row is missing."""
    # Arrange
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
    fake_session = MagicMock()
    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)]
    _patch_session(monkeypatch, fake_session)
    # Act / Assert
    with pytest.raises(ValueError, match="Workflow not found"):
        WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Non-debug resolution returns the (trigger, workflow, node config) triple."""
    # Arrange
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
    fake_session = MagicMock()
    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)]
    _patch_session(monkeypatch, fake_session)
    # Act
    got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
    # Assert
    assert got_trigger is webhook_trigger
    assert got_workflow is workflow
    assert got_node_config == {"data": {"key": "value"}}


def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    """Debug resolution skips the app-trigger status check (only two queries issued)."""
    # Arrange
    webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
    fake_session = MagicMock()
    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)]
    _patch_session(monkeypatch, fake_session)
    # Act
    got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
        "webhook-1", is_debug=True
    )
    # Assert
    assert got_trigger is webhook_trigger
    assert got_workflow is workflow
    assert got_node_config == {"data": {"mode": "debug"}}
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Unknown content types fall back to raw-text extraction and log a warning."""
    # Arrange
    warning_mock = MagicMock()
    monkeypatch.setattr(service_module.logger, "warning", warning_mock)
    webhook_trigger = MagicMock()
    # Act
    with flask_app.test_request_context(
        "/webhook",
        method="POST",
        headers={"Content-Type": "application/vnd.custom"},
        data="plain content",
    ):
        result = WebhookService.extract_webhook_data(webhook_trigger)
    # Assert
    assert result["body"] == {"raw": "plain content"}
    warning_mock.assert_called_once()


def test_extract_webhook_data_should_raise_for_request_too_large(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Payloads over WEBHOOK_REQUEST_BODY_MAX_SIZE raise RequestEntityTooLarge."""
    # Arrange
    monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
    # Act / Assert
    with flask_app.test_request_context("/webhook", method="POST", data="ab"):
        with pytest.raises(RequestEntityTooLarge):
            WebhookService.extract_webhook_data(MagicMock())
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
    """An empty octet-stream body yields {'raw': None} and no files."""
    # Arrange
    webhook_trigger = MagicMock()
    # Act
    with flask_app.test_request_context("/webhook", method="POST", data=b""):
        body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
    # Assert
    assert body == {"raw": None}
    assert files == {}


def test_extract_octet_stream_body_should_return_none_when_processing_raises(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Binary-processing failures degrade to an empty result instead of propagating."""
    # Arrange
    webhook_trigger = MagicMock()
    monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
    monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
    # Act
    with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
        body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
    # Assert
    assert body == {"raw": None}
    assert files == {}


def test_extract_text_body_should_return_empty_string_when_request_read_fails(
    flask_app: Flask,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Request read errors degrade to an empty raw string instead of propagating."""
    # Arrange
    monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
    # Act
    with flask_app.test_request_context("/webhook", method="POST", data="abc"):
        body, files = WebhookService._extract_text_body()
    # Assert
    assert body == {"raw": ""}
    assert files == {}
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """libmagic failures fall back to the generic octet-stream mimetype."""
    broken_magic = MagicMock()
    broken_magic.from_buffer.side_effect = RuntimeError("magic failed")
    monkeypatch.setattr(service_module, "magic", broken_magic)

    detected = WebhookService._detect_binary_mimetype(b"binary")

    assert detected == "application/octet-stream"
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Uploads still process when both the upload and mimetypes.guess_type lack a mimetype."""
    # Arrange
    webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
    file_obj = MagicMock()
    file_obj.to_dict.return_value = {"id": "f-1"}
    monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
    monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
    uploaded = MagicMock()
    uploaded.filename = "file.unknown"
    uploaded.content_type = None
    uploaded.read.return_value = b"content"
    # Act
    result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
    # Assert
    assert result == {"f": {"id": "f-1"}}


def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """_create_file_from_binary persists via ToolFileManager, then builds via file_factory."""
    # Arrange
    webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
    manager = MagicMock()
    manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
    monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
    expected_file = MagicMock()
    monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
    # Act
    result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
    # Assert
    assert result is expected_file
    manager.create_file_by_raw.assert_called_once()
@pytest.mark.parametrize(
    ("raw_value", "param_type", "expected"),
    [
        ("42", SegmentType.NUMBER, 42),
        ("3.14", SegmentType.NUMBER, 3.14),
        ("yes", SegmentType.BOOLEAN, True),
        ("no", SegmentType.BOOLEAN, False),
    ],
)
def test_convert_form_value_should_convert_supported_types(
    raw_value: str,
    param_type: SegmentType,  # annotation tightened: parametrize supplies SegmentType members
    expected: Any,
) -> None:
    """Form values are coerced to int/float/bool according to the declared segment type."""
    # Act
    result = WebhookService._convert_form_value("param", raw_value, param_type)
    # Assert
    assert result == expected


def test_convert_form_value_should_raise_for_unsupported_type() -> None:
    """Segment types without a form-data conversion rule (e.g. FILE) raise ValueError."""
    # Act / Assert
    with pytest.raises(ValueError, match="Unsupported type"):
        WebhookService._convert_form_value("p", "x", SegmentType.FILE)
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Unknown segment types pass the value through unchanged but emit a warning."""
    # Arrange
    warning_mock = MagicMock()
    monkeypatch.setattr(service_module.logger, "warning", warning_mock)
    # Act
    result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
    # Assert
    assert result == {"x": 1}
    warning_mock.assert_called_once()


def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
    """Conversion failures surface as a 'validation failed' ValueError."""
    # Act / Assert
    with pytest.raises(ValueError, match="validation failed"):
        WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)


def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
    """Required configured parameters must be present in the raw input."""
    # Arrange
    raw_params = {"optional": "x"}
    config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
    # Act / Assert
    with pytest.raises(ValueError, match="Required parameter missing"):
        WebhookService._process_parameters(raw_params, config, is_form_data=True)


def test_process_parameters_should_include_unconfigured_parameters() -> None:
    """Parameters absent from the config are passed through unconverted."""
    # Arrange
    raw_params = {"known": "1", "unknown": "x"}
    config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
    # Act
    result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
    # Assert
    assert result == {"known": 1, "unknown": "x"}
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
    """TEXT content with a required 'raw' entry rejects an empty body."""
    # Act / Assert
    with pytest.raises(ValueError, match="Required body content missing"):
        WebhookService._process_body_parameters(
            raw_body={"raw": ""},
            body_configs=[WebhookBodyParameter(name="raw", required=True)],
            content_type=ContentType.TEXT,
        )


def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
    """FILE-typed configs are skipped for multipart form-data (files handled elsewhere)."""
    # Arrange
    raw_body = {"message": "hello", "extra": "x"}
    body_configs = [
        WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
        WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
    ]
    # Act
    result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
    # Assert
    assert result == {"message": "hello", "extra": "x"}


def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
    """A sanitized header name (x_api_key) satisfies the configured x-api-key requirement."""
    # Arrange
    headers = {"x_api_key": "123"}
    configs = [WebhookParameter(name="x-api-key", required=True)]
    # Act — must not raise for the sanitized variant
    WebhookService._validate_required_headers(headers, configs)
    # Assert
    assert True
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
    """A required header absent from the request raises ValueError."""
    # Arrange
    headers = {"x-other": "123"}
    configs = [WebhookParameter(name="x-api-key", required=True)]
    # Act / Assert
    with pytest.raises(ValueError, match="Required header missing"):
        WebhookService._validate_required_headers(headers, configs)


def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
    """A content type different from the node's expectation yields an invalid result."""
    # Arrange
    webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
    node_data = WebhookData(method="post", content_type=ContentType.TEXT)
    # Act
    result = WebhookService._validate_http_metadata(webhook_data, node_data)
    # Assert
    assert result["valid"] is False
    assert "Content-type mismatch" in result["error"]


def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
    """Lookup also checks the lowercase 'content-type' key and strips parameters."""
    # Arrange
    headers = {"content-type": "application/json; charset=utf-8"}
    # Act
    result = WebhookService._extract_content_type(headers)
    # Assert
    assert result == "application/json"
def test_build_workflow_inputs_should_include_expected_keys() -> None:
    """The inputs mapping exposes the full payload plus headers/query/body shortcuts."""
    payload = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}

    inputs = WebhookService.build_workflow_inputs(payload)

    assert inputs["webhook_data"] == payload
    assert inputs["webhook_headers"] == {"h": "v"}
    assert inputs["webhook_query_params"] == {"q": 1}
    assert inputs["webhook_body"] == {"b": 2}
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
    """Happy path: quota is consumed and the async workflow trigger fires once."""
    # Arrange
    webhook_trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    webhook_data = {"body": {"x": 1}}
    session = MagicMock()
    _patch_session(monkeypatch, session)
    end_user = SimpleNamespace(id="end-user-1")
    monkeypatch.setattr(
        service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
    )
    quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
    monkeypatch.setattr(service_module, "QuotaType", quota_type)
    trigger_async_mock = MagicMock()
    monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
    # Act
    WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
    # Assert
    trigger_async_mock.assert_called_once()


def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Quota exhaustion marks the tenant's triggers rate-limited and re-raises."""
    # Arrange
    webhook_trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    session = MagicMock()
    _patch_session(monkeypatch, session)
    monkeypatch.setattr(
        service_module.EndUserService,
        "get_or_create_end_user_by_type",
        MagicMock(return_value=SimpleNamespace(id="end-user-1")),
    )
    quota_type = SimpleNamespace(
        TRIGGER=SimpleNamespace(
            consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
        )
    )
    monkeypatch.setattr(service_module, "QuotaType", quota_type)
    mark_rate_limited_mock = MagicMock()
    monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
    # Act / Assert
    with pytest.raises(QuotaExceededError):
        WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
    mark_rate_limited_mock.assert_called_once_with("tenant-1")


def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    """Unexpected failures are logged via logger.exception and then re-raised."""
    # Arrange
    webhook_trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    session = MagicMock()
    _patch_session(monkeypatch, session)
    monkeypatch.setattr(
        service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
    )
    logger_exception_mock = MagicMock()
    monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
    # Act / Assert
    with pytest.raises(RuntimeError, match="boom"):
        WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
    logger_exception_mock.assert_called_once()
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
    """More webhook nodes than MAX_WEBHOOK_NODES_PER_WORKFLOW aborts the sync up front."""
    # Arrange
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(
        walk_nodes=lambda _node_type: [
            (f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
        ]
    )
    # Act / Assert
    with pytest.raises(ValueError, match="maximum webhook node limit"):
        WebhookService.sync_webhook_relationships(app, workflow)


def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
    """Failure to take the Redis lock raises RuntimeError instead of syncing unsafely."""
    # Arrange
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
    lock = MagicMock()
    lock.acquire.return_value = False
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
    # Act / Assert
    with pytest.raises(RuntimeError, match="Failed to acquire lock"):
        WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """End-to-end sync: a new node gains a trigger row, the stale row is deleted,
    and the Redis mapping keys are updated (one set, one delete)."""
    # Arrange
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])

    class _WorkflowWebhookTrigger:
        # Class-level attrs stand in for column expressions used by select()/where().
        app_id = "app_id"
        tenant_id = "tenant_id"
        webhook_id = "webhook_id"
        node_id = "node_id"

        def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
            self.id = None
            self.app_id = app_id
            self.tenant_id = tenant_id
            self.node_id = node_id
            self.webhook_id = webhook_id
            self.created_by = created_by

    class _Select:
        # Chainable select() stand-in; filtering is a no-op.
        def where(self, *args: Any, **kwargs: Any) -> "_Select":
            return self

    class _Session:
        # Records adds/deletes/commits; flush() assigns ids like a real session would.
        def __init__(self) -> None:
            self.added: list[Any] = []
            self.deleted: list[Any] = []
            self.commit_count = 0
            self.existing_records = [SimpleNamespace(node_id="node-stale")]

        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: self.existing_records)

        def add(self, obj: Any) -> None:
            self.added.append(obj)

        def flush(self) -> None:
            # Mimic DB primary-key assignment on flush.
            for idx, obj in enumerate(self.added, start=1):
                if obj.id is None:
                    obj.id = f"rec-{idx}"

        def commit(self) -> None:
            self.commit_count += 1

        def delete(self, obj: Any) -> None:
            self.deleted.append(obj)

    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.return_value = None
    fake_session = _Session()
    monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
    redis_set_mock = MagicMock()
    redis_delete_mock = MagicMock()
    monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
    monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
    monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
    _patch_session(monkeypatch, fake_session)
    # Act
    WebhookService.sync_webhook_relationships(app, workflow)
    # Assert
    assert len(fake_session.added) == 1
    assert len(fake_session.deleted) == 1
    assert fake_session.commit_count == 2
    redis_set_mock.assert_called_once()
    redis_delete_mock.assert_called_once()
    lock.release.assert_called_once()
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
    """Lock-release errors are logged via logger.exception rather than raised."""
    # Arrange
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [])

    class _Select:
        # Chainable select() stand-in.
        def where(self, *args: Any, **kwargs: Any) -> "_Select":
            return self

    class _Session:
        # Minimal session: no existing records, commits are no-ops.
        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: [])

        def commit(self) -> None:
            return None

    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.side_effect = RuntimeError("release failed")
    logger_exception_mock = MagicMock()
    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
    _patch_session(monkeypatch, _Session())
    # Act
    WebhookService.sync_webhook_relationships(app, workflow)
    # Assert
    assert logger_exception_mock.call_count == 1
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
    """An unparseable response_body falls back to a default body containing 'message'."""
    # Arrange
    config_with_bad_body = {"data": {"status_code": 200, "response_body": "{bad-json"}}
    # Act
    response_body, response_status = WebhookService.generate_webhook_response(config_with_bad_body)
    # Assert
    assert response_status == 200
    assert "message" in response_body
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
    """Generated webhook ids are strings of exactly 24 characters."""
    # Arrange / Act
    generated_id = WebhookService.generate_webhook_id()
    # Assert
    assert isinstance(generated_id, str)
    assert len(generated_id) == 24
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
    """Non-string inputs pass through _sanitize_key unchanged."""
    # Arrange / Act
    sanitized = WebhookService._sanitize_key(123)  # type: ignore[arg-type]
    # Assert
    assert sanitized == 123

View File

@@ -176,3 +176,300 @@ class TestWorkflowRunService:
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory
# === Merged from test_workflow_run_service.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
from services import workflow_run_service as service_module
from services.workflow_run_service import WorkflowRunService
@pytest.fixture
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
    """Patch DifyAPIRepositoryFactory and return (node_repo, workflow_run_repo, factory) mocks."""
    # Arrange
    node_repo = MagicMock()
    workflow_run_repo = MagicMock()
    factory = SimpleNamespace(
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    # Assert
    return node_repo, workflow_run_repo, factory
def _app_model(**kwargs: Any) -> App:
    """Build a lightweight App stand-in carrying only the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(App, stub)
def _account(**kwargs: Any) -> Account:
    """Build a lightweight Account stand-in carrying only the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(Account, stub)
def _end_user(**kwargs: Any) -> EndUser:
    """Build a lightweight EndUser stand-in carrying only the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(EndUser, stub)
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
    monkeypatch: pytest.MonkeyPatch,
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """With no session_factory argument, the service builds a sessionmaker bound to db.engine."""
    # Arrange
    session_factory = MagicMock(name="session_factory")
    sessionmaker_mock = MagicMock(return_value=session_factory)
    monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
    monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))
    # Act
    service = WorkflowRunService()
    # Assert
    sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
    assert service._session_factory is session_factory
def test___init___should_create_sessionmaker_when_engine_is_provided(
    monkeypatch: pytest.MonkeyPatch,
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """When given a raw Engine, the service wraps it in a sessionmaker itself."""
    # Arrange
    class FakeEngine:
        pass

    session_factory = MagicMock(name="session_factory")
    sessionmaker_mock = MagicMock(return_value=session_factory)
    # The service detects engines via the module's Engine symbol, so patch
    # the fake class in as that symbol.
    monkeypatch.setattr(service_module, "Engine", FakeEngine)
    monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
    # BUG FIX: the original wrote `cast(Engine, FakeEngine())`, but `Engine`
    # is not imported in this module, and cast() evaluates its first argument
    # at runtime — this raised NameError. The cast exists only for the type
    # checker, so casting to Any is equivalent and always resolvable.
    engine = cast(Any, FakeEngine())
    # Act
    service = WorkflowRunService(session_factory=engine)
    # Assert
    sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
    assert service._session_factory is session_factory
def test___init___should_keep_provided_sessionmaker_and_create_repositories(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """A ready sessionmaker is stored as-is and used to build both repositories."""
    # Arrange
    node_repo, workflow_run_repo, factory = repository_factory_mocks
    provided_factory = MagicMock(name="session_factory")
    # Act
    service = WorkflowRunService(session_factory=provided_factory)
    # Assert
    assert service._session_factory is provided_factory
    assert service._node_execution_service_repo is node_repo
    assert service._workflow_run_repo is workflow_run_repo
    factory.create_api_workflow_node_execution_repository.assert_called_once_with(provided_factory)
    factory.create_api_workflow_run_repository.assert_called_once_with(provided_factory)
def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Filters are forwarded verbatim to the repository and the string limit is parsed to int."""
    # Arrange
    _, workflow_run_repo, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app_model = _app_model(tenant_id="tenant-1", id="app-1")
    expected = MagicMock(name="pagination")
    workflow_run_repo.get_paginated_workflow_runs.return_value = expected
    args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}
    # Act
    result = service.get_paginate_workflow_runs(
        app_model=app_model,
        args=args,
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
    # Assert
    assert result is expected
    workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
        tenant_id="tenant-1",
        app_id="app-1",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        limit=7,  # "7" from args must arrive as an int
        last_id="last-1",
        status="succeeded",
    )
def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Runs with an attached message gain message_id/conversation_id; runs without stay untouched."""
    # Arrange
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app_model = _app_model(tenant_id="tenant-1", id="app-1")
    run_with_message = SimpleNamespace(
        id="run-1",
        status="running",
        message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
    )
    run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
    pagination = SimpleNamespace(data=[run_with_message, run_without_message])
    monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
    # Act
    result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
    # Assert
    assert result is pagination
    assert len(result.data) == 2
    assert result.data[0].message_id == "msg-1"
    assert result.data[0].conversation_id == "conv-1"
    assert result.data[0].status == "running"
    # No message => the extra attributes must not be invented.
    assert not hasattr(result.data[1], "message_id")
    assert result.data[1].id == "run-2"
def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """get_workflow_run delegates straight to the repository scoped by tenant and app."""
    # Arrange
    _, workflow_run_repo, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app_model = _app_model(tenant_id="tenant-1", id="app-1")
    expected = MagicMock(name="workflow_run")
    workflow_run_repo.get_workflow_run_by_id.return_value = expected
    # Act
    result = service.get_workflow_run(app_model=app_model, run_id="run-1")
    # Assert
    assert result is expected
    workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
        tenant_id="tenant-1",
        app_id="app-1",
        run_id="run-1",
    )
def test_get_workflow_runs_count_should_forward_optional_filters(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
    """Optional status/time_range/triggered_from filters are forwarded unchanged to the count query."""
    # Arrange
    _, workflow_run_repo, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    app_model = _app_model(tenant_id="tenant-1", id="app-1")
    expected = {"total": 3, "succeeded": 2}
    workflow_run_repo.get_workflow_runs_count.return_value = expected
    # Act
    result = service.get_workflow_runs_count(
        app_model=app_model,
        status="succeeded",
        time_range="7d",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
    # Assert
    assert result == expected
    workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
        tenant_id="tenant-1",
        app_id="app-1",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        status="succeeded",
        time_range="7d",
    )
def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An unknown run id yields an empty execution list rather than an error."""
    # Arrange
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
    app_model = _app_model(id="app-1")
    user = _account(current_tenant_id="tenant-1")
    # Act
    result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
    # Assert
    assert result == []
def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """For EndUser callers the tenant scope comes from user.tenant_id."""
    # Arrange
    node_repo, _, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))

    class FakeEndUser:
        # Replaces the EndUser class so the service's isinstance check matches.
        def __init__(self, tenant_id: str) -> None:
            self.tenant_id = tenant_id

    monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
    user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
    app_model = _app_model(id="app-1")
    expected = [SimpleNamespace(id="exec-1")]
    node_repo.get_executions_by_workflow_run.return_value = expected
    # Act
    result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
    # Assert
    assert result == expected
    node_repo.get_executions_by_workflow_run.assert_called_once_with(
        tenant_id="tenant-end-user",
        app_id="app-1",
        workflow_run_id="run-1",
    )
def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """For Account callers the tenant scope comes from user.current_tenant_id."""
    # Arrange
    node_repo, _, _ = repository_factory_mocks
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
    app_model = _app_model(id="app-1")
    user = _account(current_tenant_id="tenant-account")
    expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
    node_repo.get_executions_by_workflow_run.return_value = expected
    # Act
    result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
    # Assert
    assert result == expected
    node_repo.get_executions_by_workflow_run.assert_called_once_with(
        tenant_id="tenant-account",
        app_id="app-1",
        workflow_run_id="run-1",
    )
def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
    repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A user whose tenant resolves to None triggers a ValueError instead of a silent query."""
    # Arrange
    service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
    monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
    app_model = _app_model(id="app-1")
    user = _account(current_tenant_id=None)
    # Act / Assert
    with pytest.raises(ValueError, match="tenant_id cannot be None"):
        service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)

View File

@@ -0,0 +1,831 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity,
DatasetEntity,
DatasetRetrieveConfigEntity,
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
)
from core.helper import encrypter
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import Account, App, AppMode, AppModelConfig
from services.workflow import workflow_converter as converter_module
from services.workflow.workflow_converter import WorkflowConverter
try:
from graphon.enums import BuiltinNodeTypes
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.message_entities import PromptMessageRole
from graphon.variables.input_entities import VariableEntity, VariableEntityType
except ModuleNotFoundError:
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
@pytest.fixture
def converter() -> WorkflowConverter:
    """Provide a fresh WorkflowConverter instance per test."""
    return WorkflowConverter()
def _app_model(**kwargs: Any) -> App:
    """Build a lightweight App stand-in carrying only the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(App, stub)
def _account(**kwargs: Any) -> Account:
    """Build a lightweight Account stand-in carrying only the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(Account, stub)
def _app_model_config(**kwargs: Any) -> AppModelConfig:
    """Build a lightweight AppModelConfig stand-in carrying only the given attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(AppModelConfig, stub)
def _build_start_graph() -> dict[str, Any]:
    """Return a minimal workflow graph with a single start node exposing name/city variables."""
    return {
        "nodes": [
            {
                "id": "start",
                "position": None,
                "data": {"type": BuiltinNodeTypes.START, "variables": [{"variable": "name"}, {"variable": "city"}]},
            }
        ],
        "edges": [],
    }
def _build_model_config(mode: str | LLMMode) -> ModelConfigEntity:
    """Build a minimal openai/gpt-4 ModelConfigEntity for the given completion/chat mode."""
    return ModelConfigEntity(provider="openai", model="gpt-4", mode=mode, parameters={}, stop=[])
@pytest.fixture
def default_variables() -> list[VariableEntity]:
    """Provide one variable of each basic input type: text-input, paragraph, select."""
    return [
        VariableEntity(variable="text_input", label="text-input", type=VariableEntityType.TEXT_INPUT),
        VariableEntity(variable="paragraph", label="paragraph", type=VariableEntityType.PARAGRAPH),
        VariableEntity(variable="select", label="select", type=VariableEntityType.SELECT),
    ]
def test__convert_to_start_node(default_variables: list[VariableEntity]) -> None:
    """The start node keeps id 'start', START type, and serializes variables with type/name."""
    result = WorkflowConverter()._convert_to_start_node(default_variables)
    assert result["id"] == "start"
    assert result["data"]["type"] == BuiltinNodeTypes.START
    assert result["data"]["variables"][0]["type"] == "text-input"
    assert result["data"]["variables"][0]["variable"] == "text_input"
def test__convert_to_http_request_node_for_chatbot(
    default_variables: list[VariableEntity],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Chat-mode external data tools convert to an HTTP request + code node pair.

    Verifies the request body targets the app-external-data-tool point, that
    sys.query and start-node inputs are templated in, and that the external
    variable maps to the generated code node.
    """
    app_model = MagicMock()
    app_model.id = "app_id"
    app_model.tenant_id = "tenant_id"
    app_model.mode = AppMode.CHAT
    extension = APIBasedExtension(
        tenant_id="tenant_id",
        name="api-1",
        api_key="encrypted_api_key",
        api_endpoint="https://dify.ai",
    )
    extension.id = "api_based_extension_id"
    workflow_converter = WorkflowConverter()
    workflow_converter._get_api_based_extension = MagicMock(return_value=extension)
    # BUG FIX: the original assigned `encrypter.decrypt_token = MagicMock(...)`
    # directly, permanently replacing the module attribute and leaking the mock
    # into every test that runs afterwards. monkeypatch restores it on teardown.
    monkeypatch.setattr(encrypter, "decrypt_token", MagicMock(return_value="api_key"))
    external_data_variables = [
        ExternalDataVariableEntity(
            variable="external_variable",
            type="api",
            config={"api_based_extension_id": "api_based_extension_id"},
        ),
    ]
    nodes, mapping = workflow_converter._convert_to_http_request_node(
        app_model=app_model,
        variables=default_variables,
        external_data_variables=external_data_variables,
    )
    assert len(nodes) == 2
    assert nodes[0]["data"]["type"] == BuiltinNodeTypes.HTTP_REQUEST
    assert nodes[1]["data"]["type"] == BuiltinNodeTypes.CODE
    body = json.loads(nodes[0]["data"]["body"]["data"])
    assert body["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
    assert body["params"]["query"] == "{{#sys.query#}}"
    assert body["params"]["inputs"]["text_input"] == "{{#start.text_input#}}"
    assert mapping == {"external_variable": "code_1"}
def test__convert_to_http_request_node_for_workflow_app(
    default_variables: list[VariableEntity],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Workflow-mode apps have no sys.query, so the templated query must be empty."""
    app_model = MagicMock()
    app_model.id = "app_id"
    app_model.tenant_id = "tenant_id"
    app_model.mode = AppMode.WORKFLOW
    extension = APIBasedExtension(
        tenant_id="tenant_id",
        name="api-1",
        api_key="encrypted_api_key",
        api_endpoint="https://dify.ai",
    )
    extension.id = "api_based_extension_id"
    workflow_converter = WorkflowConverter()
    workflow_converter._get_api_based_extension = MagicMock(return_value=extension)
    # BUG FIX: patch via monkeypatch instead of assigning to the module attribute
    # (`encrypter.decrypt_token = MagicMock(...)`), which leaked the mock into
    # all subsequent tests in the session.
    monkeypatch.setattr(encrypter, "decrypt_token", MagicMock(return_value="api_key"))
    external_data_variables = [
        ExternalDataVariableEntity(
            variable="external_variable",
            type="api",
            config={"api_based_extension_id": "api_based_extension_id"},
        ),
    ]
    nodes, _ = workflow_converter._convert_to_http_request_node(
        app_model=app_model,
        variables=default_variables,
        external_data_variables=external_data_variables,
    )
    body = json.loads(nodes[0]["data"]["body"]["data"])
    assert body["params"]["query"] == ""
def test__convert_to_knowledge_retrieval_node_for_chatbot() -> None:
    """In advanced-chat mode the retrieval node queries sys.query and keeps multi-retrieval config."""
    dataset_config = DatasetEntity(
        dataset_ids=["dataset_id_1", "dataset_id_2"],
        retrieve_config=DatasetRetrieveConfigEntity(
            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
            top_k=5,
            score_threshold=0.8,
            reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
            reranking_enabled=True,
        ),
    )
    model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
    node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
        new_app_mode=AppMode.ADVANCED_CHAT,
        dataset_config=dataset_config,
        model_config=model_config,
    )
    assert node is not None
    assert node["data"]["query_variable_selector"] == ["sys", "query"]
    assert node["data"]["multiple_retrieval_config"]["top_k"] == 5
def test__convert_to_knowledge_retrieval_node_for_workflow_app() -> None:
    """In workflow mode the retrieval node reads the query from the start node variable."""
    dataset_config = DatasetEntity(
        dataset_ids=["dataset_id_1", "dataset_id_2"],
        retrieve_config=DatasetRetrieveConfigEntity(
            query_variable="query",
            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
            top_k=5,
            score_threshold=0.8,
            reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
            reranking_enabled=True,
        ),
    )
    model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
    node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
        new_app_mode=AppMode.WORKFLOW,
        dataset_config=dataset_config,
        model_config=model_config,
    )
    assert node is not None
    assert node["data"]["query_variable_selector"] == ["start", "query"]
def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables: list[VariableEntity]) -> None:
    """A simple prompt in chat mode becomes a user message with start-node variable references."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="You are a helper for {{text_input}} and {{paragraph}}",
    )
    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.CHAT,
        new_app_mode=AppMode.ADVANCED_CHAT,
        model_config=model_config,
        graph=graph,
        prompt_template=prompt_template,
    )
    assert node["data"]["type"] == BuiltinNodeTypes.LLM
    assert node["data"]["memory"] is not None
    assert node["data"]["prompt_template"][0]["role"] == "user"
    assert "{{#start.text_input#}}" in node["data"]["prompt_template"][0]["text"]
def test__convert_to_llm_node_for_chatbot_simple_chat_model_with_empty_template(
    default_variables: list[VariableEntity],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the transform yields an empty prompt template, the node carries an empty list."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="ignored",
    )
    # Force SimplePromptTransform to return an empty parsed template regardless of input.
    monkeypatch.setattr(
        converter_module.SimplePromptTransform,
        "get_prompt_template",
        lambda self, **kwargs: {"prompt_template": PromptTemplateParser(""), "prompt_rules": {}},
    )
    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.CHAT,
        new_app_mode=AppMode.ADVANCED_CHAT,
        model_config=model_config,
        graph=graph,
        prompt_template=prompt_template,
    )
    assert node["data"]["prompt_template"] == []
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables: list[VariableEntity]) -> None:
    """Advanced chat templates are carried over as role-tagged message dicts."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
        advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
            messages=[AdvancedChatMessageEntity(text="Hello {{text_input}}", role=PromptMessageRole.USER)]
        ),
    )
    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.CHAT,
        new_app_mode=AppMode.ADVANCED_CHAT,
        model_config=model_config,
        graph=graph,
        prompt_template=prompt_template,
    )
    assert isinstance(node["data"]["prompt_template"], list)
    assert node["data"]["prompt_template"][0]["role"] == PromptMessageRole.USER.value
def test__convert_to_llm_node_for_chatbot_advanced_chat_model_without_template(
    default_variables: list[VariableEntity],
) -> None:
    """A missing advanced template yields an empty prompt list; workflow mode gets no memory."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
        advanced_chat_prompt_template=None,
    )
    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.CHAT,
        new_app_mode=AppMode.WORKFLOW,
        model_config=model_config,
        graph=graph,
        prompt_template=prompt_template,
    )
    assert node["data"]["prompt_template"] == []
    assert node["data"]["memory"] is None
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables: list[VariableEntity]) -> None:
    """Completion templates rewrite {{#query#}} to sys.query and keep role prefixes in memory."""
    workflow_converter = WorkflowConverter()
    graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
    model_config = ModelConfigEntity(
        provider="openai",
        model="gpt-3.5-turbo-instruct",
        mode=LLMMode.COMPLETION.value,
        parameters={},
        stop=[],
    )
    prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
        advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
            prompt="Hello {{text_input}} and {{#query#}}",
            role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
        ),
    )
    node = workflow_converter._convert_to_llm_node(
        original_app_mode=AppMode.COMPLETION,
        new_app_mode=AppMode.ADVANCED_CHAT,
        model_config=model_config,
        graph=graph,
        prompt_template=prompt_template,
    )
    assert node["data"]["prompt_template"]["text"].find("{{#sys.query#}}") != -1
    assert node["data"]["memory"]["role_prefix"]["user"] == "Human"
def test__convert_to_end_node() -> None:
    """The end node carries the fixed id 'end' and the END node type."""
    end_node = WorkflowConverter()._convert_to_end_node()
    assert end_node["id"] == "end"
    assert end_node["data"]["type"] == BuiltinNodeTypes.END
def test__convert_to_answer_node() -> None:
    """The answer node carries the fixed id 'answer' and the ANSWER node type."""
    answer_node = WorkflowConverter()._convert_to_answer_node()
    assert answer_node["id"] == "answer"
    assert answer_node["data"]["type"] == BuiltinNodeTypes.ANSWER
def test_convert_to_workflow_should_raise_when_app_model_config_is_missing(converter: WorkflowConverter) -> None:
    """Converting an app with no model config raises a descriptive ValueError."""
    app_model = _app_model(app_model_config=None)
    with pytest.raises(ValueError, match="App model config is required"):
        converter.convert_to_workflow(
            app_model=app_model,
            account=_account(id="account-1"),
            name="new-app",
            icon_type="emoji",
            icon="robot",
            icon_background="#fff",
        )
@pytest.mark.parametrize(
    ("source_mode", "expected_mode"),
    [
        (AppMode.CHAT, AppMode.ADVANCED_CHAT),
        (AppMode.COMPLETION, AppMode.WORKFLOW),
    ],
)
def test_convert_to_workflow_should_create_new_app_with_fallback_fields(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
    source_mode: AppMode,
    expected_mode: AppMode,
) -> None:
    """Empty name/icon arguments fall back to the source app's values; mode maps chat->advanced-chat
    and completion->workflow, the workflow is bound to the new app id, and creation is persisted
    and broadcast via app_was_created."""

    class FakeApp:
        # Replaces the App model so no ORM machinery is touched.
        def __init__(self) -> None:
            self.id = "new-app-id"

    workflow = SimpleNamespace(app_id=None)
    monkeypatch.setattr(converter, "convert_app_model_config_to_workflow", MagicMock(return_value=workflow))
    monkeypatch.setattr(converter_module, "App", FakeApp)
    db_session = SimpleNamespace(add=MagicMock(), flush=MagicMock(), commit=MagicMock())
    monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
    send_mock = MagicMock()
    monkeypatch.setattr(converter_module.app_was_created, "send", send_mock)
    account = _account(id="account-1")
    app_model = _app_model(
        tenant_id="tenant-1",
        name="Source App",
        mode=source_mode,
        icon_type="emoji",
        icon="sparkles",
        icon_background="#123456",
        enable_site=True,
        enable_api=True,
        api_rpm=10,
        api_rph=100,
        is_public=False,
        app_model_config=_app_model_config(id="config-1"),
    )
    new_app = converter.convert_to_workflow(
        app_model=app_model,
        account=account,
        name="",
        icon_type="",
        icon="",
        icon_background="",
    )
    assert new_app.name == "Source App(workflow)"
    assert new_app.mode == expected_mode
    assert new_app.icon_type == "emoji"
    assert new_app.icon == "sparkles"
    assert new_app.icon_background == "#123456"
    assert new_app.created_by == "account-1"
    assert workflow.app_id == "new-app-id"
    db_session.add.assert_called_once()
    db_session.flush.assert_called_once()
    db_session.commit.assert_called_once()
    send_mock.assert_called_once_with(new_app, account=account)
def test_convert_app_model_config_to_workflow_should_build_advanced_chat_graph_and_features(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Advanced-chat conversion chains start -> http -> knowledge -> llm -> answer and copies
    all enabled config sections into the workflow features."""
    app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT)
    app_config = SimpleNamespace(
        variables=[SimpleNamespace(variable="name")],
        external_data_variables=[SimpleNamespace(variable="ext")],
        dataset=SimpleNamespace(id="dataset"),
        model=SimpleNamespace(),
        prompt_template=SimpleNamespace(),
        additional_features=SimpleNamespace(file_upload=SimpleNamespace()),
        app_model_config_dict={
            "opening_statement": "hello",
            "suggested_questions": ["q1"],
            "suggested_questions_after_answer": True,
            "speech_to_text": True,
            "text_to_speech": {"enabled": True},
            "file_upload": {"enabled": True},
            "sensitive_word_avoidance": {"enabled": True},
            "retriever_resource": {"enabled": True},
        },
    )

    class FakeWorkflow:
        # Replaces the Workflow model; just records the constructor kwargs.
        VERSION_DRAFT = "draft"

        def __init__(self, **kwargs: Any) -> None:
            self.__dict__.update(kwargs)

    # Each node-conversion helper is stubbed to a canned node so only the
    # graph assembly logic is under test.
    monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.ADVANCED_CHAT))
    monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config))
    monkeypatch.setattr(
        converter,
        "_convert_to_start_node",
        MagicMock(
            return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}}
        ),
    )
    monkeypatch.setattr(
        converter,
        "_convert_to_http_request_node",
        MagicMock(
            return_value=(
                [{"id": "http", "position": None, "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}],
                {"ext": "code_1"},
            )
        ),
    )
    monkeypatch.setattr(
        converter,
        "_convert_to_knowledge_retrieval_node",
        MagicMock(
            return_value={"id": "knowledge", "position": None, "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}
        ),
    )
    monkeypatch.setattr(
        converter,
        "_convert_to_llm_node",
        MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}),
    )
    monkeypatch.setattr(
        converter,
        "_convert_to_answer_node",
        MagicMock(return_value={"id": "answer", "position": None, "data": {"type": BuiltinNodeTypes.ANSWER}}),
    )
    monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow)
    db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock())
    monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
    workflow = converter.convert_app_model_config_to_workflow(
        app_model=app_model,
        app_model_config=_app_model_config(id="cfg"),
        account_id="account-1",
    )
    graph = json.loads(workflow.graph)
    node_ids = [node["id"] for node in graph["nodes"]]
    assert node_ids == ["start", "http", "knowledge", "llm", "answer"]
    features = json.loads(workflow.features)
    assert "opening_statement" in features
    assert "retriever_resource" in features
    db_session.add.assert_called_once()
    db_session.commit.assert_called_once()
def test_convert_app_model_config_to_workflow_should_build_workflow_mode_with_end_node(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Workflow-mode conversion skips retrieval (stub returns None), ends with an end node,
    and keeps only the feature keys present in the source config dict."""
    app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.COMPLETION)
    app_config = SimpleNamespace(
        variables=[SimpleNamespace(variable="name")],
        external_data_variables=[],
        dataset=SimpleNamespace(id="dataset"),
        model=SimpleNamespace(),
        prompt_template=SimpleNamespace(),
        additional_features=None,
        app_model_config_dict={
            "text_to_speech": {"enabled": False},
            "file_upload": {"enabled": False},
            "sensitive_word_avoidance": {"enabled": False},
        },
    )

    class FakeWorkflow:
        # Replaces the Workflow model; just records the constructor kwargs.
        VERSION_DRAFT = "draft"

        def __init__(self, **kwargs: Any) -> None:
            self.__dict__.update(kwargs)

    monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.WORKFLOW))
    monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config))
    monkeypatch.setattr(
        converter,
        "_convert_to_start_node",
        MagicMock(
            return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}}
        ),
    )
    monkeypatch.setattr(converter, "_convert_to_knowledge_retrieval_node", MagicMock(return_value=None))
    monkeypatch.setattr(
        converter,
        "_convert_to_llm_node",
        MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}),
    )
    monkeypatch.setattr(
        converter,
        "_convert_to_end_node",
        MagicMock(return_value={"id": "end", "position": None, "data": {"type": BuiltinNodeTypes.END}}),
    )
    monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow)
    db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock())
    monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
    workflow = converter.convert_app_model_config_to_workflow(
        app_model=app_model,
        app_model_config=_app_model_config(id="cfg"),
        account_id="account-1",
    )
    graph = json.loads(workflow.graph)
    node_ids = [node["id"] for node in graph["nodes"]]
    assert node_ids == ["start", "llm", "end"]
    features = json.loads(workflow.features)
    assert set(features.keys()) == {"text_to_speech", "file_upload", "sensitive_word_avoidance"}
def test_convert_to_app_config_should_route_to_correct_manager(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Routing: AGENT_CHAT mode or is_agent=True -> agent manager; CHAT -> chat manager;
    COMPLETION -> completion manager."""
    agent_result = SimpleNamespace(kind="agent")
    chat_result = SimpleNamespace(kind="chat")
    completion_result = SimpleNamespace(kind="completion")
    monkeypatch.setattr(
        converter_module.AgentChatAppConfigManager, "get_app_config", MagicMock(return_value=agent_result)
    )
    monkeypatch.setattr(converter_module.ChatAppConfigManager, "get_app_config", MagicMock(return_value=chat_result))
    monkeypatch.setattr(
        converter_module.CompletionAppConfigManager,
        "get_app_config",
        MagicMock(return_value=completion_result),
    )
    from_agent_mode = converter._convert_to_app_config(
        app_model=_app_model(mode=AppMode.AGENT_CHAT, is_agent=False),
        app_model_config=_app_model_config(id="cfg-1"),
    )
    from_agent_flag = converter._convert_to_app_config(
        app_model=_app_model(mode=AppMode.CHAT, is_agent=True),
        app_model_config=_app_model_config(id="cfg-2"),
    )
    from_chat_mode = converter._convert_to_app_config(
        app_model=_app_model(mode=AppMode.CHAT, is_agent=False),
        app_model_config=_app_model_config(id="cfg-3"),
    )
    from_completion_mode = converter._convert_to_app_config(
        app_model=_app_model(mode=AppMode.COMPLETION, is_agent=False),
        app_model_config=_app_model_config(id="cfg-4"),
    )
    assert from_agent_mode is agent_result
    assert from_agent_flag is agent_result
    assert from_chat_mode is chat_result
    assert from_completion_mode is completion_result
def test_convert_to_app_config_should_raise_for_invalid_app_mode(converter: WorkflowConverter) -> None:
    """WORKFLOW mode has no legacy app config, so conversion must be rejected."""
    with pytest.raises(ValueError, match="Invalid app mode"):
        converter._convert_to_app_config(
            app_model=_app_model(mode=AppMode.WORKFLOW, is_agent=False),
            app_model_config=_app_model_config(id="cfg"),
        )
def test_convert_to_http_request_node_should_skip_non_api_and_missing_extension_id(
    converter: WorkflowConverter,
) -> None:
    """Variables that are not api-typed, or lack an extension id, produce no HTTP nodes."""
    ineligible_variables = [
        ExternalDataVariableEntity(variable="skip_type", type="dataset", config={"api_based_extension_id": "x"}),
        ExternalDataVariableEntity(variable="skip_config", type="api", config={}),
    ]
    nodes, variable_mapping = converter._convert_to_http_request_node(
        app_model=_app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT),
        variables=[],
        external_data_variables=ineligible_variables,
    )
    assert nodes == []
    assert variable_mapping == {}
def test_convert_to_knowledge_retrieval_node_should_return_none_for_workflow_without_query_variable(
    converter: WorkflowConverter,
) -> None:
    """A workflow-mode retrieval node needs a query variable; without one nothing is built."""
    retrieve_config = DatasetRetrieveConfigEntity(
        query_variable=None,
        retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
    )
    node = converter._convert_to_knowledge_retrieval_node(
        new_app_mode=AppMode.WORKFLOW,
        dataset_config=DatasetEntity(dataset_ids=["ds-1"], retrieve_config=retrieve_config),
        model_config=_build_model_config(mode=LLMMode.CHAT),
    )
    assert node is None
def test_convert_to_llm_node_should_raise_when_simple_chat_template_missing(
    converter: WorkflowConverter,
) -> None:
    """SIMPLE prompt type without a template string must be rejected for chat models."""
    template_without_text = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE)
    with pytest.raises(ValueError, match="Simple prompt template is required"):
        converter._convert_to_llm_node(
            original_app_mode=AppMode.CHAT,
            new_app_mode=AppMode.ADVANCED_CHAT,
            graph=_build_start_graph(),
            model_config=_build_model_config(mode=LLMMode.CHAT),
            prompt_template=template_without_text,
        )
def test_convert_to_llm_node_should_raise_when_prompt_template_parser_type_is_invalid_for_chat(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A non-parser object returned by the prompt transform must trigger a TypeError."""
    monkeypatch.setattr(
        converter_module.SimplePromptTransform,
        "get_prompt_template",
        lambda self, **kwargs: {"prompt_template": "invalid"},
    )
    simple_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="Hello {{name}}",
    )
    with pytest.raises(TypeError, match="Expected PromptTemplateParser"):
        converter._convert_to_llm_node(
            original_app_mode=AppMode.CHAT,
            new_app_mode=AppMode.ADVANCED_CHAT,
            graph=_build_start_graph(),
            model_config=_build_model_config(mode=LLMMode.CHAT),
            prompt_template=simple_template,
        )
def test_convert_to_llm_node_should_raise_when_simple_completion_template_missing(
    converter: WorkflowConverter,
) -> None:
    """SIMPLE prompt type without a template string must be rejected for completion models."""
    template_without_text = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE)
    with pytest.raises(ValueError, match="Simple prompt template is required"):
        converter._convert_to_llm_node(
            original_app_mode=AppMode.COMPLETION,
            new_app_mode=AppMode.WORKFLOW,
            graph=_build_start_graph(),
            model_config=_build_model_config(mode=LLMMode.COMPLETION),
            prompt_template=template_without_text,
        )
def test_convert_to_llm_node_should_raise_when_completion_prompt_rules_type_is_invalid(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Non-dict prompt_rules coming back from the transform must trigger a TypeError."""
    monkeypatch.setattr(
        converter_module.SimplePromptTransform,
        "get_prompt_template",
        lambda self, **kwargs: {"prompt_template": PromptTemplateParser("Hello {{name}}"), "prompt_rules": "invalid"},
    )
    simple_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="Hello {{name}}",
    )
    with pytest.raises(TypeError, match="Expected dict for prompt_rules"):
        converter._convert_to_llm_node(
            original_app_mode=AppMode.COMPLETION,
            new_app_mode=AppMode.ADVANCED_CHAT,
            graph=_build_start_graph(),
            model_config=_build_model_config(mode=LLMMode.COMPLETION),
            prompt_template=simple_template,
        )
def test_convert_to_llm_node_should_use_empty_text_for_advanced_completion_without_template(
    converter: WorkflowConverter,
) -> None:
    """Advanced completion without a template falls back to empty text and no memory."""
    node = converter._convert_to_llm_node(
        original_app_mode=AppMode.COMPLETION,
        new_app_mode=AppMode.WORKFLOW,
        graph=_build_start_graph(),
        model_config=_build_model_config(mode=LLMMode.COMPLETION),
        prompt_template=PromptTemplateEntity(
            prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
            advanced_completion_prompt_template=None,
        ),
    )
    node_data = node["data"]
    assert node_data["prompt_template"]["text"] == ""
    assert node_data["memory"] is None
def test_replace_template_variables_should_replace_start_and_external_references(converter: WorkflowConverter) -> None:
    """Start-node variables become {{#start.X#}}; external ones point at their HTTP node result."""
    replaced = converter._replace_template_variables(
        "Hello {{name}} from {{city}} with {{weather}}",
        [{"variable": "name"}, {"variable": "city"}],
        {"weather": "code_1"},
    )
    assert replaced == "Hello {{#start.name#}} from {{#start.city#}} with {{#code_1.result#}}"
def test_graph_helpers_should_create_edges_append_nodes_and_choose_mode(converter: WorkflowConverter) -> None:
    """Edge ids join source/target, appended nodes land last, and legacy modes map correctly."""
    base_graph = {"nodes": [{"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START}}], "edges": []}
    llm_node = {"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}

    assert converter._create_edge("start", "llm") == {"id": "start-llm", "source": "start", "target": "llm"}

    grown_graph = converter._append_node(base_graph, llm_node)
    assert grown_graph["nodes"][-1]["id"] == "llm"
    assert grown_graph["edges"][-1]["source"] == "start"

    assert converter._get_new_app_mode(_app_model(mode=AppMode.COMPLETION)) == AppMode.WORKFLOW
    assert converter._get_new_app_mode(_app_model(mode=AppMode.CHAT)) == AppMode.ADVANCED_CHAT
def test_get_api_based_extension_should_raise_when_extension_not_found(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A missing extension row must surface as a ValueError after a single DB lookup."""
    scalar_mock = MagicMock(return_value=None)
    monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=SimpleNamespace(scalar=scalar_mock)))
    with pytest.raises(ValueError, match="API Based Extension not found"):
        converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1")
    scalar_mock.assert_called_once()
def test_get_api_based_extension_should_return_entity_when_found(
    converter: WorkflowConverter,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the row exists, the lookup returns it after a single DB query."""
    stored_extension = SimpleNamespace(id="ext-1")
    scalar_mock = MagicMock(return_value=stored_extension)
    monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=SimpleNamespace(scalar=scalar_mock)))
    fetched = converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1")
    assert fetched is stored_extension
    scalar_mock.assert_called_once()

View File

@@ -1,10 +1,9 @@
from __future__ import annotations
import json
import queue
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from itertools import cycle
from threading import Event
import pytest
@@ -224,3 +223,577 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -
buffer_state.task_id_ready.set()
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
assert task_id == expected
# === Merged from test_workflow_event_snapshot_service_additional.py ===
import json
import queue
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from threading import Event
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import StreamEvent
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services import workflow_event_snapshot_service as service_module
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
    """Build a minimal WorkflowRun fixture; only the execution status varies per test."""
    run_fields: dict[str, Any] = {
        "id": "run-1",
        "tenant_id": "tenant-1",
        "app_id": "app-1",
        "workflow_id": "workflow-1",
        "type": "workflow",
        "triggered_from": "app-run",
        "version": "v1",
        "graph": None,
        "inputs": json.dumps({"query": "hello"}),
        "status": status,
        "outputs": json.dumps({}),
        "error": None,
        "elapsed_time": 1.2,
        "total_tokens": 5,
        "total_steps": 2,
        "created_by_role": CreatorUserRole.END_USER,
        "created_by": "user-1",
        "created_at": datetime(2024, 1, 1, tzinfo=UTC),
    }
    return WorkflowRun(**run_fields)
def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext:
    """Assemble a serializable resumption context whose generate entity carries *task_id*."""
    generate_entity = WorkflowAppGenerateEntity(
        task_id=task_id,
        app_config=WorkflowUIBasedAppConfig(
            tenant_id="tenant-1",
            app_id="app-1",
            app_mode=AppMode.WORKFLOW,
            workflow_id="workflow-1",
        ),
        inputs={},
        files=[],
        user_id="user-1",
        stream=True,
        invoke_from=InvokeFrom.EXPLORE,
        call_depth=0,
        workflow_execution_id="run-1",
    )
    state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
    state.outputs = {"answer": "ok"}
    return WorkflowResumptionContext(
        generate_entity=_WorkflowGenerateEntityWrapper(entity=generate_entity),
        serialized_graph_runtime_state=state.dumps(),
    )
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionMaker:
    """Callable standing in for a sessionmaker; each call wraps the same session."""

    def __init__(self, session: Any) -> None:
        self._held = session

    def __call__(self) -> _SessionContext:
        return _SessionContext(self._held)
class _SubscriptionContext:
def __init__(self, subscription: Any) -> None:
self._subscription = subscription
def __enter__(self) -> Any:
return self._subscription
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _Topic:
    """Topic stub whose subscribe() always yields the same subscription."""

    def __init__(self, subscription: Any) -> None:
        self._sub = subscription

    def subscribe(self) -> _SubscriptionContext:
        return _SubscriptionContext(self._sub)
class _StaticSubscription:
    """Subscription stub that never produces a message (receive always returns None)."""

    def receive(self, timeout: int = 1) -> None:
        return None
@dataclass(frozen=True)
class _PauseEntity(WorkflowPauseEntity):
    """Frozen test double for WorkflowPauseEntity backed by a raw state payload."""

    # Serialized resumption-context bytes, returned verbatim by get_state().
    state: bytes

    @property
    def id(self) -> str:
        return "pause-1"

    @property
    def workflow_execution_id(self) -> str:
        return "run-1"

    @property
    def resumed_at(self) -> datetime | None:
        # Never resumed in these tests.
        return None

    @property
    def paused_at(self) -> datetime:
        return datetime(2024, 1, 1, tzinfo=UTC)

    def get_state(self) -> bytes:
        return self.state

    def get_pause_reasons(self) -> list[Any]:
        return []
def test_get_message_context_should_return_none_when_no_message() -> None:
    """No message row for the run means no message context."""
    empty_session = SimpleNamespace(scalar=MagicMock(return_value=None))
    context = service_module._get_message_context(
        cast(sessionmaker[Session], _SessionMaker(empty_session)), "run-1"
    )
    assert context is None
def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None:
    """A message without created_at is surfaced with a zero timestamp."""
    message_row = SimpleNamespace(id="msg-1", conversation_id="conv-1", created_at=None, answer="answer")
    session_with_message = SimpleNamespace(scalar=MagicMock(return_value=message_row))
    context = service_module._get_message_context(
        cast(sessionmaker[Session], _SessionMaker(session_with_message)), "run-1"
    )
    assert context is not None
    assert context.created_at == 0
    assert (context.message_id, context.conversation_id, context.answer) == ("msg-1", "conv-1", "answer")
def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None:
    """An absent pause entity yields no resumption context."""
    assert service_module._load_resumption_context(None) is None
def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None:
    """Undecodable pause state is treated as missing rather than raising."""
    broken_entity = _PauseEntity(state=b"not-a-valid-state")
    assert service_module._load_resumption_context(broken_entity) is None
def test_load_resumption_context_should_parse_valid_state_into_context() -> None:
    """Round-trip: a serialized context is restored with its original task id."""
    serialized_state = _build_resumption_context_additional(task_id="task-ctx").dumps().encode()
    restored = service_module._load_resumption_context(_PauseEntity(state=serialized_state))
    assert restored is not None
    assert restored.get_generate_entity().task_id == "task-ctx"
def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None:
    """Without a resumption context or buffer, the run id itself serves as the task id."""
    resolved = service_module._resolve_task_id(
        resumption_context=None,
        buffer_state=None,
        workflow_run_id="run-1",
    )
    assert resolved == "run-1"
@pytest.mark.parametrize(
    ("payload", "expected"),
    [
        (b'{"event":"node_started"}', {"event": "node_started"}),
        (b"invalid-json", None),
        (b"[]", None),
    ],
)
def test_parse_event_message_should_parse_only_json_object(
    payload: bytes,
    expected: dict[str, Any] | None,
) -> None:
    """Only a top-level JSON object parses; junk bytes and JSON arrays come back as None."""
    assert service_module._parse_event_message(payload) == expected
def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None:
    """Finished events always terminate; paused ones only when include_paused is set."""
    finished_payload = {"event": StreamEvent.WORKFLOW_FINISHED.value}
    paused_payload = {"event": StreamEvent.WORKFLOW_PAUSED.value}

    assert service_module._is_terminal_event(finished_payload, include_paused=False) is True
    assert service_module._is_terminal_event(paused_payload, include_paused=False) is False
    assert service_module._is_terminal_event(paused_payload, include_paused=True) is True
    # A bare ping string (not an event payload dict) is never terminal.
    assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False
def test_apply_message_context_should_update_payload_when_context_exists() -> None:
    """Message context fields are copied onto the outgoing payload in place."""
    event_payload: dict[str, Any] = {"event": "workflow_started"}
    service_module._apply_message_context(
        event_payload,
        MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000),
    )
    assert event_payload["conversation_id"] == "conv-1"
    assert event_payload["message_id"] == "msg-1"
    assert event_payload["created_at"] == 1700000000
def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None:
    """The buffering thread records the first task_id it sees and enqueues the event."""
    # Arrange
    class Subscription:
        # Yields one event payload, then signals "no more data" with None.
        def __init__(self) -> None:
            self._calls = 0
        def receive(self, timeout: int = 1) -> bytes | None:
            self._calls += 1
            if self._calls == 1:
                return b'{"event":"node_started","task_id":"task-1"}'
            return None
    subscription = Subscription()
    # Act
    buffer_state = service_module._start_buffering(subscription)
    ready = buffer_state.task_id_ready.wait(timeout=1)
    event = buffer_state.queue.get(timeout=1)
    buffer_state.stop_event.set()
    finished = buffer_state.done_event.wait(timeout=1)
    # Assert: task id hint was captured and the event made it into the queue.
    assert ready is True
    assert finished is True
    assert buffer_state.task_id_hint == "task-1"
    assert event["event"] == "node_started"
def test_start_buffering_should_drop_old_event_when_queue_is_full(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the buffer queue is full, the oldest event is evicted so the newest fits."""
    # Arrange
    class QueueWithSingleFull:
        # Raises queue.Full exactly once so the drop-oldest branch executes.
        def __init__(self) -> None:
            self._first_put = True
            self.items: list[dict[str, Any]] = [{"event": "old"}]
        def put_nowait(self, item: dict[str, Any]) -> None:
            if self._first_put:
                self._first_put = False
                raise queue.Full
            self.items.append(item)
        def get_nowait(self) -> dict[str, Any]:
            if not self.items:
                raise queue.Empty
            return self.items.pop(0)
        def empty(self) -> bool:
            return len(self.items) == 0
    fake_queue = QueueWithSingleFull()
    # Replace the queue constructor so _start_buffering uses our instrumented queue.
    monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)
    class Subscription:
        def __init__(self) -> None:
            self._calls = 0
        def receive(self, timeout: int = 1) -> bytes | None:
            self._calls += 1
            if self._calls == 1:
                return b'{"event":"node_started","task_id":"task-2"}'
            return None
    subscription = Subscription()
    # Act
    buffer_state = service_module._start_buffering(subscription)
    ready = buffer_state.task_id_ready.wait(timeout=1)
    buffer_state.stop_event.set()
    finished = buffer_state.done_event.wait(timeout=1)
    # Assert: the seeded "old" event was dropped and the new event retained.
    assert ready is True
    assert finished is True
    assert fake_queue.items[-1]["task_id"] == "task-2"
def test_start_buffering_should_set_done_event_when_subscription_raises() -> None:
    """A failing subscription must still signal completion via done_event."""
    class FailingSubscription:
        def receive(self, timeout: int = 1) -> bytes | None:
            raise RuntimeError("subscription failure")

    buffer_state = service_module._start_buffering(FailingSubscription())
    assert buffer_state.done_event.wait(timeout=1) is True
def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Stream opens with a ping, replays snapshot events, and stops on a terminal event."""
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(
        service_module,
        "_get_message_context",
        MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
    )
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    # Snapshot replay yields a finished event, which terminates the stream.
    monkeypatch.setattr(
        service_module,
        "_build_snapshot_events",
        MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
    )
    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.ADVANCED_CHAT,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
        )
    )
    # Assert
    assert events[0] == StreamEvent.PING.value
    finished_event = cast(Mapping[str, Any], events[1])
    assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
    assert buffer_state.stop_event.is_set() is True
    node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
    called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
    assert called_kwargs["workflow_run_id"] == "run-1"
def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An idle stream keeps pinging on the interval and exits once idle_timeout elapses."""
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    class AlwaysEmptyQueue:
        # Forces the stream into its idle/ping path on every poll.
        def empty(self) -> bool:
            return False
        def get(self, timeout: int = 1) -> None:
            raise queue.Empty
    buffer_state = BufferState(
        queue=AlwaysEmptyQueue(),  # type: ignore[arg-type]
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    # Fake clock readings: 6.0 exceeds ping_interval (5.0) triggering the second ping;
    # 21.0/26.0 exceed idle_timeout (20.0), ending the stream.
    time_values = cycle([0.0, 6.0, 21.0, 26.0])
    monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.WORKFLOW,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
            idle_timeout=20.0,
            ping_interval=5.0,
        )
    )
    # Assert
    assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
    assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Stream must terminate once the buffer thread is done and its queue is drained."""
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    # Pre-mark the buffer as finished so the stream exits immediately after the ping.
    buffer_state.done_event.set()
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.WORKFLOW,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
        )
    )
    # Assert
    assert events == [StreamEvent.PING.value]
    assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A failing pause lookup is swallowed; the stream proceeds with pause_entity=None."""
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
    topic = _Topic(_StaticSubscription())
    # get_workflow_pause raising must not abort the stream.
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
    monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.WORKFLOW,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
        )
    )
    # Assert
    assert events[0] == StreamEvent.PING.value
    assert snapshot_builder.call_args.kwargs["pause_entity"] is None

View File

@@ -10,6 +10,8 @@ This module tests the document indexing task functionality including:
"""
import uuid
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
@@ -1113,13 +1115,17 @@ class TestAdvancedScenarios:
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task)
# Assert
# Verify delete was called to clean up task key
mock_redis.delete.assert_called_once()
expected_task_key = f"tenant_document_indexing_task:{tenant_id}"
# Verify the correct key was deleted (contains tenant_id and "document_indexing")
delete_call_args = mock_redis.delete.call_args[0][0]
assert tenant_id in delete_call_args
assert "document_indexing" in delete_call_args
# Verify the task key for this tenant was deleted (do not assert call count; fixtures may be shared).
mock_redis.delete.assert_any_call(expected_task_key)
deleted_keys = [delete_call.args[0] for delete_call in mock_redis.delete.call_args_list if delete_call.args]
assert expected_task_key in deleted_keys
deleted_task_key = next(key for key in deleted_keys if key == expected_task_key)
assert tenant_id in deleted_task_key
assert "document_indexing" in deleted_task_key
def test_billing_disabled_skips_limit_checks(
self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner, mock_feature_service
@@ -1510,3 +1516,475 @@ class TestRobustness:
# Verify the exception message
assert "Feature service" in str(exc_info.value) or isinstance(exc_info.value, Exception)
class _SessionContext:
def __init__(self, session: MagicMock) -> None:
self._session = session
def __enter__(self) -> MagicMock:
return self._session
def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[override]
return None
class TestDocumentIndexingTaskSummaryFlow:
"""Additional coverage for summary and tenant queue branches."""
def test_should_return_when_dataset_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test early return when dataset does not exist."""
# Arrange
session = MagicMock()
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = None
session.query.side_effect = lambda model: dataset_query
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
features_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.FeatureService.get_features", features_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
features_mock.assert_not_called()
    def test_should_mark_documents_error_when_batch_upload_limit_exceeded(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test batch upload limit triggers error handling.

        With BATCH_UPLOAD_LIMIT forced to 1 and two documents submitted, each
        document should be flagged as errored and the session committed.
        """
        # Arrange
        dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
        document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None)
        dataset_query = MagicMock()
        dataset_query.where.return_value = dataset_query
        dataset_query.first.return_value = dataset
        document_query = MagicMock()
        document_query.where.return_value = document_query
        document_query.first.return_value = document
        session = MagicMock()
        # Route model lookups: Dataset queries hit the dataset stub, everything else the document stub.
        session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query
        monkeypatch.setattr(
            "tasks.document_indexing_task.session_factory.create_session",
            MagicMock(return_value=_SessionContext(session)),
        )
        features = SimpleNamespace(
            billing=SimpleNamespace(
                enabled=True,
                subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL),
            ),
            vector_space=SimpleNamespace(limit=0, size=0),
        )
        monkeypatch.setattr(
            "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
        )
        # Limit of 1 guarantees the two-document batch below exceeds it.
        monkeypatch.setattr("tasks.document_indexing_task.dify_config.BATCH_UPLOAD_LIMIT", "1")
        # Act
        _document_indexing("dataset-1", ["doc-1", "doc-2"])
        # Assert
        assert document.indexing_status == "error"
        assert "batch upload limit" in document.error
        session.commit.assert_called_once()
    def test_should_queue_summary_generation_for_completed_documents(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test summary generation is queued for eligible documents.

        Only doc-1 qualifies: completed status and a non-qa_model doc form.
        doc-2 is skipped for its form, doc-3 for its processing status.
        """
        # Arrange
        dataset = SimpleNamespace(
            id="dataset-1",
            tenant_id="tenant-1",
            indexing_technique="high_quality",
            summary_index_setting={"enable": True},
        )
        doc_eligible = SimpleNamespace(
            id="doc-1",
            indexing_status="completed",
            doc_form="text",
            need_summary=True,
        )
        doc_skip_form = SimpleNamespace(
            id="doc-2",
            indexing_status="completed",
            doc_form="qa_model",
            need_summary=True,
        )
        doc_skip_status = SimpleNamespace(
            id="doc-3",
            indexing_status="processing",
            doc_form="text",
            need_summary=True,
        )
        dataset_query = MagicMock()
        dataset_query.where.return_value = dataset_query
        dataset_query.first.return_value = dataset
        phase1_docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2"), SimpleNamespace(id="doc-3")]
        phase1_document_query = MagicMock()
        phase1_document_query.where.return_value = phase1_document_query
        phase1_document_query.all.return_value = phase1_docs
        summary_document_query = MagicMock()
        summary_document_query.where.return_value = summary_document_query
        summary_document_query.all.return_value = [doc_eligible, doc_skip_form, doc_skip_status]
        # Three sessions, consumed in order: dataset lookup, indexing phase, summary phase.
        session1 = MagicMock()
        session2 = MagicMock()
        session2.begin.return_value = nullcontext()
        session3 = MagicMock()
        session1.query.side_effect = lambda model: dataset_query
        session2.query.side_effect = lambda model: phase1_document_query
        session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query
        create_session_mock = MagicMock(
            side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]
        )
        monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
        features = SimpleNamespace(
            billing=SimpleNamespace(enabled=False),
            vector_space=SimpleNamespace(limit=0, size=0),
        )
        monkeypatch.setattr(
            "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
        )
        indexing_runner = MagicMock()
        monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner))
        delay_mock = MagicMock()
        monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
        # Act
        _document_indexing("dataset-1", ["doc-1", "doc-2", "doc-3"])
        # Assert
        delay_mock.assert_called_once_with("dataset-1", "doc-1", None)
    def test_should_continue_when_summary_queue_fails(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test summary queueing errors are swallowed.

        generate_summary_index_task.delay raising must not propagate out of
        _document_indexing; the task is still attempted exactly once.
        """
        # Arrange
        dataset = SimpleNamespace(
            id="dataset-1",
            tenant_id="tenant-1",
            indexing_technique="high_quality",
            summary_index_setting={"enable": True},
        )
        doc_eligible = SimpleNamespace(
            id="doc-1",
            indexing_status="completed",
            doc_form="text",
            need_summary=True,
        )
        dataset_query = MagicMock()
        dataset_query.where.return_value = dataset_query
        dataset_query.first.return_value = dataset
        phase1_query = MagicMock()
        phase1_query.where.return_value = phase1_query
        phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
        summary_query = MagicMock()
        summary_query.where.return_value = summary_query
        summary_query.all.return_value = [doc_eligible]
        # Three sessions, consumed in order: dataset lookup, indexing phase, summary phase.
        session1 = MagicMock()
        session2 = MagicMock()
        session2.begin.return_value = nullcontext()
        session3 = MagicMock()
        session1.query.side_effect = lambda model: dataset_query
        session2.query.side_effect = lambda model: phase1_query
        session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
        monkeypatch.setattr(
            "tasks.document_indexing_task.session_factory.create_session",
            MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
        )
        features = SimpleNamespace(
            billing=SimpleNamespace(enabled=False),
            vector_space=SimpleNamespace(limit=0, size=0),
        )
        monkeypatch.setattr(
            "tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
        )
        indexing_runner = MagicMock()
        monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner))
        # Queueing blows up; _document_indexing should survive it.
        delay_mock = MagicMock(side_effect=Exception("boom"))
        monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
        # Act
        _document_indexing("dataset-1", ["doc-1"])
        # Assert
        delay_mock.assert_called_once_with("dataset-1", "doc-1", None)
def test_should_return_when_dataset_missing_after_indexing(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test early return when dataset is missing after indexing."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.side_effect = [dataset, None]
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
session3.query.assert_called()
def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test summary generation skipped when indexing_technique is not high_quality."""
# Arrange
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
indexing_technique="economy",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
delay_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
delay_mock.assert_not_called()
def test_should_skip_summary_generation_when_indexing_paused(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test summary generation is skipped when indexing is paused."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)])
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
runner = MagicMock()
runner.run.side_effect = DocumentIsPausedError("paused")
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner))
delay_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
delay_mock.assert_not_called()
def test_should_handle_indexing_runner_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test generic indexing runner exception is handled."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)]),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
runner = MagicMock()
runner.run.side_effect = RuntimeError("boom")
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner))
delay_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
delay_mock.assert_not_called()
def test_should_log_missing_document_entry_in_summary_list(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test falsey document entries are handled in summary iteration."""
# Arrange
class _FalseyDocument:
def __init__(self, doc_id: str) -> None:
self.id = doc_id
def __bool__(self) -> bool:
return False
dataset = SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
indexing_technique="high_quality",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
phase1_query = MagicMock()
phase1_query.where.return_value = phase1_query
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
summary_query = MagicMock()
summary_query.where.return_value = summary_query
summary_query.all.return_value = [_FalseyDocument("missing-doc")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_query
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False),
vector_space=SimpleNamespace(limit=0, size=0),
)
monkeypatch.setattr(
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
)
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
delay_mock = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
# Act
_document_indexing("dataset-1", ["doc-1"])
# Assert
delay_mock.assert_not_called()
def test_normal_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test normal indexing task delegates to tenant queue handler."""
# Arrange
handler = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler)
# Act
normal_document_indexing_task("tenant-1", "dataset-1", ["doc-1"])
# Assert
handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], normal_document_indexing_task)
def test_priority_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test priority indexing task delegates to tenant queue handler."""
# Arrange
handler = MagicMock()
monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler)
# Act
priority_document_indexing_task("tenant-1", "dataset-1", ["doc-1"])
# Assert
handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], priority_document_indexing_task)