mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 05:09:19 +08:00
test: add unit tests for services and tasks part-4 (#33223)
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com> Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com> Co-authored-by: Dev Sharma <50591491+cryptus-neoxys@users.noreply.github.com> Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
This commit is contained in:
@@ -129,6 +129,7 @@ class VariableTruncator(BaseTruncator):
|
||||
used_size += self.calculate_json_size(key)
|
||||
if used_size > budget:
|
||||
truncated_mapping[key] = "..."
|
||||
is_truncated = True
|
||||
continue
|
||||
value_budget = (budget - used_size) // (length - len(truncated_mapping))
|
||||
if isinstance(value, Segment):
|
||||
@@ -164,9 +165,9 @@ class VariableTruncator(BaseTruncator):
|
||||
result = self._truncate_segment(segment, self._max_size_bytes)
|
||||
|
||||
if result.value_size > self._max_size_bytes:
|
||||
if isinstance(result.value, str):
|
||||
result = self._truncate_string(result.value, self._max_size_bytes)
|
||||
return TruncationResult(StringSegment(value=result.value), True)
|
||||
if isinstance(result.value, StringSegment):
|
||||
fallback_result = self._truncate_string(result.value.value, self._max_size_bytes)
|
||||
return TruncationResult(StringSegment(value=fallback_result.value), True)
|
||||
|
||||
# Apply final fallback - convert to JSON string and truncate
|
||||
json_str = dumps_with_segments(result.value, ensure_ascii=False)
|
||||
|
||||
@@ -85,3 +85,644 @@ def test_get_provider_list_strips_credentials(service_with_fake_configurations:
|
||||
assert len(custom_models) == 1
|
||||
# The sanitizer should drop credentials in list response
|
||||
assert custom_models[0].credentials is None
|
||||
|
||||
|
||||
# === Merged from test_model_provider_service.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
|
||||
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from models.provider import ProviderType
|
||||
from services import model_provider_service as service_module
|
||||
from services.errors.app_model_config import ProviderNotFoundError
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
|
||||
manager = MagicMock()
|
||||
service = ModelProviderService()
|
||||
service._get_provider_manager = MagicMock(return_value=manager)
|
||||
return service, manager
|
||||
|
||||
|
||||
def _build_provider_configuration(
|
||||
*,
|
||||
provider_name: str = "openai",
|
||||
supported_model_types: list[ModelType] | None = None,
|
||||
custom_models: list[Any] | None = None,
|
||||
custom_config_available: bool = True,
|
||||
) -> SimpleNamespace:
|
||||
if supported_model_types is None:
|
||||
supported_model_types = [ModelType.LLM]
|
||||
return SimpleNamespace(
|
||||
provider=SimpleNamespace(
|
||||
provider=provider_name,
|
||||
label=I18nObject(en_US=provider_name),
|
||||
description=None,
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
background=None,
|
||||
help=None,
|
||||
supported_model_types=supported_model_types,
|
||||
configurate_methods=[],
|
||||
provider_credential_schema=None,
|
||||
model_credential_schema=None,
|
||||
),
|
||||
preferred_provider_type=ProviderType.CUSTOM,
|
||||
custom_configuration=SimpleNamespace(
|
||||
provider=SimpleNamespace(
|
||||
current_credential_id="cred-1",
|
||||
current_credential_name="Credential 1",
|
||||
available_credentials=[],
|
||||
),
|
||||
models=custom_models,
|
||||
can_added_models=[],
|
||||
),
|
||||
system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
|
||||
is_custom_configuration_available=lambda: custom_config_available,
|
||||
)
|
||||
|
||||
|
||||
def test__get_provider_configuration_should_return_configuration_when_provider_exists() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
provider_configuration = SimpleNamespace(name="provider-config")
|
||||
manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
|
||||
# Act
|
||||
result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
|
||||
|
||||
# Assert
|
||||
assert result is provider_configuration
|
||||
|
||||
|
||||
def test__get_provider_configuration_should_raise_error_when_provider_is_missing() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_configurations.return_value = {}
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ProviderNotFoundError, match="does not exist"):
|
||||
service._get_provider_configuration(tenant_id="tenant-1", provider="missing")
|
||||
|
||||
|
||||
def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
allowed = _build_provider_configuration(
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
custom_config_available=False,
|
||||
)
|
||||
filtered = _build_provider_configuration(
|
||||
provider_name="embedding",
|
||||
supported_model_types=[ModelType.TEXT_EMBEDDING],
|
||||
custom_config_available=True,
|
||||
)
|
||||
manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
|
||||
|
||||
# Act
|
||||
result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].provider == "openai"
|
||||
assert result[0].custom_configuration.status.value == "no-configure"
|
||||
|
||||
|
||||
def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
|
||||
class _Model:
|
||||
def __init__(self, model_name: str) -> None:
|
||||
self.model_name = model_name
|
||||
|
||||
def model_dump(self) -> dict[str, Any]:
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"label": {"en_US": self.model_name},
|
||||
"model_type": ModelType.LLM,
|
||||
"features": [],
|
||||
"fetch_from": FetchFrom.PREDEFINED_MODEL,
|
||||
"model_properties": {},
|
||||
"deprecated": False,
|
||||
"status": ModelStatus.ACTIVE,
|
||||
"load_balancing_enabled": False,
|
||||
"has_invalid_load_balancing_configs": False,
|
||||
"provider": {
|
||||
"provider": "openai",
|
||||
"label": {"en_US": "OpenAI"},
|
||||
"icon_small": None,
|
||||
"icon_small_dark": None,
|
||||
"supported_model_types": [ModelType.LLM],
|
||||
},
|
||||
}
|
||||
|
||||
provider_configurations = SimpleNamespace(
|
||||
get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
|
||||
)
|
||||
manager.get_configurations.return_value = provider_configurations
|
||||
|
||||
# Act
|
||||
result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].model == "gpt-4o"
|
||||
assert result[1].provider.provider == "openai"
|
||||
provider_configurations.get_models.assert_called_once_with(provider="openai")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
|
||||
[
|
||||
(
|
||||
"get_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"get_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
{"token": "abc"},
|
||||
),
|
||||
(
|
||||
"validate_provider_credentials",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
|
||||
"validate_provider_credentials",
|
||||
({"token": "abc"},),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"create_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}, "credential_name": "A"},
|
||||
"create_provider_credential",
|
||||
({"token": "abc"}, "A"),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"update_provider_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"credentials": {"token": "abc"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "B",
|
||||
},
|
||||
"update_provider_credential",
|
||||
{"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"delete_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"switch_active_provider_credential",
|
||||
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
|
||||
"switch_active_provider_credential",
|
||||
{"credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_provider_credential_methods_should_delegate_to_provider_configuration(
|
||||
method_name: str,
|
||||
method_kwargs: dict[str, Any],
|
||||
provider_method_name: str,
|
||||
provider_call_kwargs: Any,
|
||||
provider_return: Any,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
getattr(provider_configuration, provider_method_name).return_value = provider_return
|
||||
get_provider_config_mock = MagicMock(return_value=provider_configuration)
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
|
||||
|
||||
# Act
|
||||
result = getattr(service, method_name)(**method_kwargs)
|
||||
|
||||
# Assert
|
||||
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
|
||||
provider_method = getattr(provider_configuration, provider_method_name)
|
||||
if isinstance(provider_call_kwargs, tuple):
|
||||
provider_method.assert_called_once_with(*provider_call_kwargs)
|
||||
elif isinstance(provider_call_kwargs, dict):
|
||||
provider_method.assert_called_once_with(**provider_call_kwargs)
|
||||
else:
|
||||
provider_method.assert_called_once_with(provider_call_kwargs)
|
||||
if method_name == "get_provider_credential":
|
||||
assert result == {"token": "abc"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
|
||||
[
|
||||
(
|
||||
"get_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"get_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
{"api_key": "x"},
|
||||
),
|
||||
(
|
||||
"validate_model_credentials",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
},
|
||||
"validate_custom_model_credentials",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"create_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_name": "cred-a",
|
||||
},
|
||||
"create_custom_model_credential",
|
||||
{
|
||||
"model_type": ModelType.LLM,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_name": "cred-a",
|
||||
},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"update_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "cred-b",
|
||||
},
|
||||
"update_custom_model_credential",
|
||||
{
|
||||
"model_type": ModelType.LLM,
|
||||
"model": "gpt-4o",
|
||||
"credentials": {"api_key": "x"},
|
||||
"credential_id": "cred-1",
|
||||
"credential_name": "cred-b",
|
||||
},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"delete_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"switch_active_custom_model_credential",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"switch_custom_model_credential",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"add_model_credential_to_model_list",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
"credential_id": "cred-1",
|
||||
},
|
||||
"add_model_credential_to_model",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"remove_model",
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"provider": "openai",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model": "gpt-4o",
|
||||
},
|
||||
"delete_custom_model",
|
||||
{"model_type": ModelType.LLM, "model": "gpt-4o"},
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_custom_model_methods_should_convert_model_type_and_delegate(
|
||||
method_name: str,
|
||||
method_kwargs: dict[str, Any],
|
||||
provider_method_name: str,
|
||||
expected_kwargs: dict[str, Any],
|
||||
provider_return: Any,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
getattr(provider_configuration, provider_method_name).return_value = provider_return
|
||||
get_provider_config_mock = MagicMock(return_value=provider_configuration)
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
|
||||
|
||||
# Act
|
||||
result = getattr(service, method_name)(**method_kwargs)
|
||||
|
||||
# Assert
|
||||
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
|
||||
getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
|
||||
if method_name == "get_model_credential":
|
||||
assert result == {"api_key": "x"}
|
||||
|
||||
|
||||
def test_get_models_by_model_type_should_group_active_non_deprecated_models() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
openai_provider = SimpleNamespace(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
)
|
||||
anthropic_provider = SimpleNamespace(
|
||||
provider="anthropic",
|
||||
label=I18nObject(en_US="Anthropic"),
|
||||
icon_small=None,
|
||||
icon_small_dark=None,
|
||||
)
|
||||
models = [
|
||||
SimpleNamespace(
|
||||
provider=openai_provider,
|
||||
model="gpt-4o",
|
||||
label=I18nObject(en_US="GPT-4o"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=ModelStatus.ACTIVE,
|
||||
load_balancing_enabled=False,
|
||||
deprecated=False,
|
||||
),
|
||||
SimpleNamespace(
|
||||
provider=openai_provider,
|
||||
model="old-openai",
|
||||
label=I18nObject(en_US="Old OpenAI"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=ModelStatus.ACTIVE,
|
||||
load_balancing_enabled=False,
|
||||
deprecated=True,
|
||||
),
|
||||
SimpleNamespace(
|
||||
provider=anthropic_provider,
|
||||
model="old-anthropic",
|
||||
label=I18nObject(en_US="Old Anthropic"),
|
||||
model_type=ModelType.LLM,
|
||||
features=[],
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=ModelStatus.ACTIVE,
|
||||
load_balancing_enabled=False,
|
||||
deprecated=True,
|
||||
),
|
||||
]
|
||||
provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
|
||||
manager.get_configurations.return_value = provider_configurations
|
||||
|
||||
# Act
|
||||
result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
|
||||
assert len(result) == 1
|
||||
assert result[0].provider == "openai"
|
||||
assert len(result[0].models) == 1
|
||||
assert result[0].models[0].model == "gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("credentials", "schema", "expected_count"),
|
||||
[
|
||||
(None, None, 0),
|
||||
({"api_key": "x"}, None, 0),
|
||||
(
|
||||
{"api_key": "x"},
|
||||
SimpleNamespace(
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
type=ParameterType.FLOAT,
|
||||
)
|
||||
]
|
||||
),
|
||||
1,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
|
||||
credentials: dict[str, Any] | None,
|
||||
schema: Any,
|
||||
expected_count: int,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
provider_configuration.get_current_credentials.return_value = credentials
|
||||
provider_configuration.get_model_schema.return_value = schema
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
# Act
|
||||
result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")
|
||||
|
||||
# Assert
|
||||
assert len(result) == expected_count
|
||||
provider_configuration.get_current_credentials.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4o")
|
||||
if credentials:
|
||||
provider_configuration.get_model_schema.assert_called_once_with(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials=credentials,
|
||||
)
|
||||
else:
|
||||
provider_configuration.get_model_schema.assert_not_called()
|
||||
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.return_value = SimpleNamespace(
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
provider=SimpleNamespace(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
icon_small=None,
|
||||
supported_model_types=[ModelType.LLM],
|
||||
),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.model == "gpt-4o"
|
||||
assert result.provider.provider == "openai"
|
||||
manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)
|
||||
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.return_value = None
|
||||
|
||||
# Act
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.side_effect = RuntimeError("boom")
|
||||
|
||||
# Act
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_update_default_model_of_model_type_should_delegate_to_provider_manager() -> None:
|
||||
# Arrange
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
|
||||
# Act
|
||||
service.update_default_model_of_model_type(
|
||||
tenant_id="tenant-1",
|
||||
model_type=ModelType.LLM.value,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
# Assert
|
||||
manager.update_default_model_record.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
model_type=ModelType.LLM,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
|
||||
factory_constructor = MagicMock(return_value=factory_instance)
|
||||
monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
|
||||
|
||||
# Act
|
||||
result = service.get_model_provider_icon(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
icon_type="icon_small",
|
||||
lang="en_US",
|
||||
)
|
||||
|
||||
# Assert
|
||||
factory_constructor.assert_called_once_with(tenant_id="tenant-1")
|
||||
factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
|
||||
assert result == (b"icon-bytes", "image/png")
|
||||
|
||||
|
||||
def test_switch_preferred_provider_should_convert_enum_and_delegate(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
# Act
|
||||
service.switch_preferred_provider(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
preferred_provider_type=ProviderType.SYSTEM.value,
|
||||
)
|
||||
|
||||
# Assert
|
||||
provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "provider_method_name"),
|
||||
[
|
||||
("enable_model", "enable_model"),
|
||||
("disable_model", "disable_model"),
|
||||
],
|
||||
)
|
||||
def test_model_enablement_methods_should_convert_model_type_and_delegate(
|
||||
method_name: str,
|
||||
provider_method_name: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
# Act
|
||||
getattr(service, method_name)(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM.value,
|
||||
)
|
||||
|
||||
# Assert
|
||||
getattr(provider_configuration, provider_method_name).assert_called_once_with(
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
@@ -316,7 +316,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result == expected_detail
|
||||
@@ -346,7 +346,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result["name"] == f"App from {mode}"
|
||||
@@ -369,7 +369,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
@@ -392,7 +392,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
@@ -432,9 +432,197 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result["model_config"] == complex_model_config
|
||||
assert len(result["workflows"]) == 2
|
||||
assert len(result["tools"]) == 3
|
||||
|
||||
|
||||
# === Merged from test_recommended_app_service_additional.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from services import recommended_app_service as service_module
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], result)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_db_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(session=session))
|
||||
|
||||
# Assert
|
||||
return session
|
||||
|
||||
|
||||
def _mock_factory_for_apps(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
*,
|
||||
mode: str,
|
||||
result: dict[str, Any],
|
||||
fallback_result: dict[str, Any] | None = None,
|
||||
) -> tuple[MagicMock, MagicMock]:
|
||||
retrieval_instance = MagicMock()
|
||||
retrieval_instance.get_recommended_apps_and_categories.return_value = result
|
||||
retrieval_factory = MagicMock(return_value=retrieval_instance)
|
||||
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False)
|
||||
monkeypatch.setattr(
|
||||
service_module.RecommendAppRetrievalFactory,
|
||||
"get_recommend_app_factory",
|
||||
MagicMock(return_value=retrieval_factory),
|
||||
)
|
||||
|
||||
builtin_instance = MagicMock()
|
||||
if fallback_result is not None:
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result
|
||||
monkeypatch.setattr(
|
||||
service_module.RecommendAppRetrievalFactory,
|
||||
"get_buildin_recommend_app_retrieval",
|
||||
MagicMock(return_value=builtin_instance),
|
||||
)
|
||||
return retrieval_instance, builtin_instance
|
||||
|
||||
|
||||
def test_get_recommended_apps_and_categories_should_not_query_trial_table_when_trial_feature_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mocked_db_session: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]}
|
||||
retrieval_instance, builtin_instance = _mock_factory_for_apps(
|
||||
monkeypatch,
|
||||
mode="remote",
|
||||
result=expected,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_system_features",
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called()
|
||||
mocked_db_session.scalar.assert_not_called()
|
||||
|
||||
|
||||
def test_get_recommended_apps_and_categories_should_fallback_and_enrich_can_trial_when_trial_feature_enabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mocked_db_session: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
remote_result = {"recommended_apps": [], "categories": []}
|
||||
fallback_result = {"recommended_apps": [{"app_id": "app-1"}, {"app_id": "app-2"}], "categories": ["all"]}
|
||||
_, builtin_instance = _mock_factory_for_apps(
|
||||
monkeypatch,
|
||||
mode="remote",
|
||||
result=remote_result,
|
||||
fallback_result=fallback_result,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_system_features",
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
|
||||
)
|
||||
mocked_db_session.scalar.side_effect = [SimpleNamespace(id="trial-app"), None]
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")
|
||||
|
||||
# Assert
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
assert result["recommended_apps"][0]["can_trial"] is True
|
||||
assert result["recommended_apps"][1]["can_trial"] is False
|
||||
assert mocked_db_session.scalar.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("trial_query_result", "expected_can_trial"),
|
||||
[
|
||||
(SimpleNamespace(id="trial"), True),
|
||||
(None, False),
|
||||
],
|
||||
)
|
||||
def test_get_recommend_app_detail_should_set_can_trial_when_trial_feature_enabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mocked_db_session: MagicMock,
|
||||
trial_query_result: Any,
|
||||
expected_can_trial: bool,
|
||||
) -> None:
|
||||
# Arrange
|
||||
detail = {"id": "app-1", "name": "Test App"}
|
||||
retrieval_instance = MagicMock()
|
||||
retrieval_instance.get_recommend_app_detail.return_value = detail
|
||||
retrieval_factory = MagicMock(return_value=retrieval_instance)
|
||||
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False)
|
||||
monkeypatch.setattr(
|
||||
service_module.RecommendAppRetrievalFactory,
|
||||
"get_recommend_app_factory",
|
||||
MagicMock(return_value=retrieval_factory),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_system_features",
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
|
||||
)
|
||||
mocked_db_session.scalar.return_value = trial_query_result
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail("app-1"))
|
||||
|
||||
# Assert
|
||||
assert result["id"] == "app-1"
|
||||
assert result["can_trial"] is expected_can_trial
|
||||
mocked_db_session.scalar.assert_called_once()
|
||||
|
||||
|
||||
def test_add_trial_app_record_should_increment_count_when_existing_record_found(
|
||||
mocked_db_session: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
existing_record = SimpleNamespace(count=3)
|
||||
mocked_db_session.scalar.return_value = existing_record
|
||||
|
||||
# Act
|
||||
RecommendedAppService.add_trial_app_record("app-1", "account-1")
|
||||
|
||||
# Assert
|
||||
assert existing_record.count == 4
|
||||
mocked_db_session.scalar.assert_called_once()
|
||||
mocked_db_session.commit.assert_called_once()
|
||||
mocked_db_session.add.assert_not_called()
|
||||
|
||||
|
||||
def test_add_trial_app_record_should_create_new_record_when_no_existing_record(
|
||||
mocked_db_session: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocked_db_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
RecommendedAppService.add_trial_app_record("app-2", "account-2")
|
||||
|
||||
# Assert
|
||||
mocked_db_session.scalar.assert_called_once()
|
||||
mocked_db_session.add.assert_called_once()
|
||||
added = mocked_db_session.add.call_args.args[0]
|
||||
assert added.app_id == "app-2"
|
||||
assert added.account_id == "account-2"
|
||||
assert added.count == 1
|
||||
mocked_db_session.commit.assert_called_once()
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
|
||||
from events.event_handlers.sync_workflow_schedule_when_app_published import (
|
||||
sync_schedule_from_workflow,
|
||||
)
|
||||
@@ -14,6 +17,8 @@ from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
from models.workflow import Workflow
|
||||
from services.errors.account import AccountNotFoundError
|
||||
from services.trigger import schedule_service as service_module
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
|
||||
|
||||
@@ -775,5 +780,158 @@ class TestSyncScheduleFromWorkflow(unittest.TestCase):
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_mock() -> MagicMock:
|
||||
return MagicMock(spec=Session)
|
||||
|
||||
|
||||
def _workflow(**kwargs: Any) -> Workflow:
|
||||
return cast(Workflow, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_update_schedule_should_update_only_node_id_without_recomputing_time(
|
||||
session_mock: MagicMock,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
schedule = MagicMock(spec=WorkflowSchedulePlan)
|
||||
schedule.cron_expression = "0 10 * * *"
|
||||
schedule.timezone = "UTC"
|
||||
session_mock.get.return_value = schedule
|
||||
|
||||
next_run_mock = MagicMock(return_value=datetime(2026, 1, 1, 10, 0, tzinfo=UTC))
|
||||
monkeypatch.setattr(service_module, "calculate_next_run_at", next_run_mock)
|
||||
|
||||
# Act
|
||||
result = ScheduleService.update_schedule(
|
||||
session=session_mock,
|
||||
schedule_id="schedule-1",
|
||||
updates=SchedulePlanUpdate(node_id="node-new"),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is schedule
|
||||
assert schedule.node_id == "node-new"
|
||||
next_run_mock.assert_not_called()
|
||||
session_mock.flush.assert_called_once()
|
||||
|
||||
|
||||
def test_get_tenant_owner_should_raise_when_account_record_missing(session_mock: MagicMock) -> None:
|
||||
# Arrange
|
||||
join = SimpleNamespace(account_id="account-404")
|
||||
session_mock.execute.return_value.scalar_one_or_none.return_value = join
|
||||
session_mock.get.return_value = None
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(AccountNotFoundError, match="Account not found: account-404"):
|
||||
ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")
|
||||
|
||||
|
||||
def test_get_tenant_owner_should_raise_when_no_owner_or_admin_found(session_mock: MagicMock) -> None:
|
||||
# Arrange
|
||||
session_mock.execute.return_value.scalar_one_or_none.side_effect = [None, None]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(AccountNotFoundError, match="Account not found for tenant: tenant-1"):
|
||||
ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")
|
||||
|
||||
|
||||
def test_update_next_run_at_should_raise_when_schedule_not_found(session_mock: MagicMock) -> None:
|
||||
# Arrange
|
||||
session_mock.get.return_value = None
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ScheduleNotFoundError, match="Schedule not found: schedule-1"):
|
||||
ScheduleService.update_next_run_at(session=session_mock, schedule_id="schedule-1")
|
||||
|
||||
|
||||
def test_to_schedule_config_should_build_from_cron_mode() -> None:
|
||||
# Arrange
|
||||
node_config: dict[str, Any] = {
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"mode": "cron",
|
||||
"cron_expression": "0 12 * * *",
|
||||
"timezone": "Asia/Kolkata",
|
||||
},
|
||||
}
|
||||
|
||||
# Act
|
||||
result = ScheduleService.to_schedule_config(node_config=node_config)
|
||||
|
||||
# Assert
|
||||
assert result.node_id == "node-1"
|
||||
assert result.cron_expression == "0 12 * * *"
|
||||
assert result.timezone == "Asia/Kolkata"
|
||||
|
||||
|
||||
def test_to_schedule_config_should_raise_for_cron_mode_without_expression() -> None:
|
||||
# Arrange
|
||||
node_config = {"id": "node-1", "data": {"mode": "cron", "cron_expression": ""}}
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ScheduleConfigError, match="Cron expression is required for cron mode"):
|
||||
ScheduleService.to_schedule_config(node_config=node_config)
|
||||
|
||||
|
||||
def test_to_schedule_config_should_build_from_visual_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
node_config = {
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"mode": "visual",
|
||||
"frequency": "daily",
|
||||
"visual_config": {"time": "9:30 AM"},
|
||||
"timezone": "UTC",
|
||||
},
|
||||
}
|
||||
monkeypatch.setattr(ScheduleService, "visual_to_cron", MagicMock(return_value="30 9 * * *"))
|
||||
|
||||
# Act
|
||||
result = ScheduleService.to_schedule_config(node_config=node_config)
|
||||
|
||||
# Assert
|
||||
assert result.cron_expression == "30 9 * * *"
|
||||
|
||||
|
||||
def test_to_schedule_config_should_raise_for_invalid_mode() -> None:
|
||||
# Arrange
|
||||
node_config = {"id": "node-1", "data": {"mode": "manual"}}
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ScheduleConfigError, match="Invalid schedule mode: manual"):
|
||||
ScheduleService.to_schedule_config(node_config=node_config)
|
||||
|
||||
|
||||
def test_extract_schedule_config_should_raise_when_graph_is_empty() -> None:
|
||||
# Arrange
|
||||
workflow = _workflow(graph_dict={})
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ScheduleConfigError, match="Workflow graph is empty"):
|
||||
ScheduleService.extract_schedule_config(workflow=workflow)
|
||||
|
||||
|
||||
def test_extract_schedule_config_should_raise_when_mode_invalid() -> None:
|
||||
# Arrange
|
||||
workflow = _workflow(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "schedule-1",
|
||||
"data": {
|
||||
"type": TRIGGER_SCHEDULE_NODE_TYPE,
|
||||
"mode": "invalid",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ScheduleConfigError, match="Invalid schedule mode: invalid"):
|
||||
ScheduleService.extract_schedule_config(workflow=workflow)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -12,6 +12,7 @@ This test suite covers all functionality of the current VariableTruncator includ
|
||||
import functools
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -199,14 +200,14 @@ class TestArrayTruncation:
|
||||
|
||||
def test_small_array_no_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that small arrays are not truncated."""
|
||||
small_array = [1, 2]
|
||||
small_array: list[object] = [1, 2]
|
||||
result = small_truncator._truncate_array(small_array, 1000)
|
||||
assert result.value == small_array
|
||||
assert result.truncated is False
|
||||
|
||||
def test_array_element_limit_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that arrays over element limit are truncated."""
|
||||
large_array = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
|
||||
large_array: list[object] = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
|
||||
result = small_truncator._truncate_array(large_array, 1000)
|
||||
|
||||
assert result.truncated is True
|
||||
@@ -215,7 +216,7 @@ class TestArrayTruncation:
|
||||
def test_array_size_budget_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test array truncation due to size budget constraints."""
|
||||
# Create array with strings that will exceed size budget
|
||||
large_strings = ["very long string " * 5, "another long string " * 5]
|
||||
large_strings: list[object] = ["very long string " * 5, "another long string " * 5]
|
||||
result = small_truncator._truncate_array(large_strings, 50)
|
||||
|
||||
assert result.truncated is True
|
||||
@@ -276,10 +277,10 @@ class TestObjectTruncation:
|
||||
|
||||
# Values should be truncated if they exist
|
||||
for key, value in result.value.items():
|
||||
if isinstance(value, str):
|
||||
original_value = obj_with_long_values[key]
|
||||
# Value should be same or smaller
|
||||
assert len(value) <= len(original_value)
|
||||
assert isinstance(value, str)
|
||||
original_value = obj_with_long_values[key]
|
||||
# Value should be same or smaller
|
||||
assert len(value) <= len(original_value)
|
||||
|
||||
def test_object_key_dropping(self, small_truncator):
|
||||
"""Test object truncation where keys are dropped due to size constraints."""
|
||||
@@ -506,10 +507,9 @@ class TestEdgeCases:
|
||||
truncator = VariableTruncator(string_length_limit=10)
|
||||
|
||||
# Unicode characters
|
||||
unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀" # Each emoji counts as 1 character
|
||||
unicode_text = "你好世界你好世界你好世界" # Multi-byte UTF-8 characters
|
||||
result = truncator.truncate(StringSegment(value=unicode_text))
|
||||
if len(unicode_text) > 10:
|
||||
assert result.truncated is True
|
||||
assert result.truncated is True
|
||||
|
||||
# Special JSON characters
|
||||
special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
|
||||
@@ -631,13 +631,12 @@ class TestIntegrationScenarios:
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
# Should handle all data types appropriately
|
||||
if result.truncated:
|
||||
# Verify the result is smaller or equal than original
|
||||
original_size = truncator.calculate_json_size(mixed_data)
|
||||
if isinstance(result.result, ObjectSegment):
|
||||
result_size = truncator.calculate_json_size(result.result.value)
|
||||
assert result_size <= original_size
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, ObjectSegment)
|
||||
# Verify the result is smaller or equal than original
|
||||
original_size = truncator.calculate_json_size(mixed_data)
|
||||
result_size = truncator.calculate_json_size(result.result.value)
|
||||
assert result_size <= original_size
|
||||
|
||||
def test_file_and_array_file_variable_mapping(self, file):
|
||||
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
|
||||
@@ -675,3 +674,229 @@ def test_dummy_variable_truncator_methods():
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.result == segment
|
||||
assert result.truncated is False
|
||||
|
||||
|
||||
# === Merged from test_variable_truncator_additional.py ===
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
|
||||
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
|
||||
from services import variable_truncator as truncator_module
|
||||
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
|
||||
|
||||
|
||||
class _AbstractPassthrough(BaseTruncator):
|
||||
def truncate(self, segment: Any) -> TruncationResult:
|
||||
# Arrange / Act
|
||||
return super().truncate(segment) # type: ignore[misc]
|
||||
|
||||
def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
|
||||
# Arrange / Act
|
||||
return super().truncate_variable_mapping(v) # type: ignore[misc]
|
||||
|
||||
|
||||
def test_base_truncator_methods_should_execute_abstract_placeholders() -> None:
|
||||
# Arrange
|
||||
passthrough = _AbstractPassthrough()
|
||||
|
||||
# Act
|
||||
truncate_result = passthrough.truncate(StringSegment(value="x"))
|
||||
mapping_result = passthrough.truncate_variable_mapping({"a": 1})
|
||||
|
||||
# Assert
|
||||
assert truncate_result is None
|
||||
assert mapping_result is None
|
||||
|
||||
|
||||
def test_default_should_use_dify_config_limits(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
|
||||
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
|
||||
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)
|
||||
|
||||
# Act
|
||||
truncator = VariableTruncator.default()
|
||||
|
||||
# Assert
|
||||
assert truncator._max_size_bytes == 111
|
||||
assert truncator._array_element_limit == 7
|
||||
assert truncator._string_length_limit == 33
|
||||
|
||||
|
||||
def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=5)
|
||||
mapping = {"very_long_key": "value"}
|
||||
|
||||
# Act
|
||||
result, truncated = truncator.truncate_variable_mapping(mapping)
|
||||
|
||||
# Assert
|
||||
assert result == {"very_long_key": "..."}
|
||||
assert truncated is True
|
||||
|
||||
|
||||
def test_truncate_variable_mapping_should_handle_segment_values() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=100)
|
||||
mapping = {"seg": StringSegment(value="hello")}
|
||||
|
||||
# Act
|
||||
result, truncated = truncator.truncate_variable_mapping(mapping)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result["seg"], StringSegment)
|
||||
assert result["seg"].value == "hello"
|
||||
assert truncated is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expected"),
|
||||
[
|
||||
(None, False),
|
||||
(True, False),
|
||||
(1, False),
|
||||
(1.5, False),
|
||||
("x", True),
|
||||
({"k": "v"}, True),
|
||||
],
|
||||
)
|
||||
def test_json_value_needs_truncation_should_match_expected_rules(value: Any, expected: bool) -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = VariableTruncator._json_value_needs_truncation(value)
|
||||
|
||||
# Assert
|
||||
assert result is expected
|
||||
|
||||
|
||||
def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=10)
|
||||
forced_result = truncator_module._PartResult(
|
||||
value=StringSegment(value="this is too long"),
|
||||
value_size=100,
|
||||
truncated=True,
|
||||
)
|
||||
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
|
||||
|
||||
# Act
|
||||
result = truncator.truncate(StringSegment(value="input"))
|
||||
|
||||
# Assert
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
assert not result.result.value.startswith('"')
|
||||
|
||||
|
||||
def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator()
|
||||
monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(AssertionError):
|
||||
truncator._truncate_segment(IntegerSegment(value=1), 10)
|
||||
|
||||
|
||||
def test_calculate_json_size_should_unwrap_segment_values() -> None:
|
||||
# Arrange
|
||||
segment = StringSegment(value="abc")
|
||||
|
||||
# Act
|
||||
size = VariableTruncator.calculate_json_size(segment)
|
||||
|
||||
# Assert
|
||||
assert size == VariableTruncator.calculate_json_size("abc")
|
||||
|
||||
|
||||
def test_calculate_json_size_should_handle_updated_variable_instances() -> None:
|
||||
# Arrange
|
||||
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
|
||||
|
||||
# Act
|
||||
size = VariableTruncator.calculate_json_size(updated)
|
||||
|
||||
# Assert
|
||||
assert size > 0
|
||||
|
||||
|
||||
def test_maybe_qa_structure_should_validate_shape() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
|
||||
assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
|
||||
assert VariableTruncator._maybe_qa_structure({}) is False
|
||||
|
||||
|
||||
def test_maybe_parent_child_structure_should_validate_shape() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True
|
||||
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False
|
||||
assert (
|
||||
VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"}) is False
|
||||
)
|
||||
|
||||
|
||||
def test_truncate_object_should_truncate_segment_values_inside_object() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
|
||||
mapping = {"s": StringSegment(value="long-content")}
|
||||
|
||||
# Act
|
||||
result = truncator._truncate_object(mapping, 20)
|
||||
|
||||
# Assert
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.value["s"], StringSegment)
|
||||
|
||||
|
||||
def test_truncate_json_primitives_should_handle_updated_variable_input() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=100)
|
||||
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
|
||||
|
||||
# Act
|
||||
result = truncator._truncate_json_primitives(updated, 100)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result.value, dict)
|
||||
|
||||
|
||||
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator()
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(AssertionError):
|
||||
truncator._truncate_json_primitives(object(), 100) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=10)
|
||||
forced_segment = ObjectSegment(value={"k": "v"})
|
||||
forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
|
||||
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
|
||||
|
||||
# Act
|
||||
result = truncator.truncate(ObjectSegment(value={"a": "b"}))
|
||||
|
||||
# Assert
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
|
||||
@@ -559,3 +559,757 @@ class TestWebhookServiceUnit:
|
||||
|
||||
result = _prepare_webhook_execution("test_webhook", is_debug=True)
|
||||
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
|
||||
|
||||
|
||||
# === Merged from test_webhook_service_additional.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from graphon.variables.types import SegmentType
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import RequestEntityTooLarge
|
||||
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger import webhook_service as service_module
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self, result: Any) -> None:
|
||||
self._result = result
|
||||
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._result
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
|
||||
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _workflow(**kwargs: Any) -> Workflow:
|
||||
return cast(Workflow, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _app(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.return_value = _FakeQuery(None)
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Webhook not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="App trigger not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="rate limited"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="disabled"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"key": "value"}}
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
"webhook-1", is_debug=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"mode": "debug"}}
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/vnd.custom"},
|
||||
data="plain content",
|
||||
):
|
||||
result = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result["body"] == {"raw": "plain content"}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_raise_for_request_too_large(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
|
||||
|
||||
# Act / Assert
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
|
||||
with pytest.raises(RequestEntityTooLarge):
|
||||
WebhookService.extract_webhook_data(MagicMock())
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b""):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
|
||||
body, files = WebhookService._extract_text_body()
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": ""}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_magic = MagicMock()
|
||||
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
|
||||
monkeypatch.setattr(service_module, "magic", fake_magic)
|
||||
|
||||
# Act
|
||||
result = WebhookService._detect_binary_mimetype(b"binary")
|
||||
|
||||
# Assert
|
||||
assert result == "application/octet-stream"
|
||||
|
||||
|
||||
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
file_obj = MagicMock()
|
||||
file_obj.to_dict.return_value = {"id": "f-1"}
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
|
||||
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
|
||||
|
||||
uploaded = MagicMock()
|
||||
uploaded.filename = "file.unknown"
|
||||
uploaded.content_type = None
|
||||
uploaded.read.return_value = b"content"
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result == {"f": {"id": "f-1"}}
|
||||
|
||||
|
||||
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
manager = MagicMock()
|
||||
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
|
||||
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
|
||||
expected_file = MagicMock()
|
||||
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
|
||||
|
||||
# Act
|
||||
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result is expected_file
|
||||
manager.create_file_by_raw.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw_value", "param_type", "expected"),
|
||||
[
|
||||
("42", SegmentType.NUMBER, 42),
|
||||
("3.14", SegmentType.NUMBER, 3.14),
|
||||
("yes", SegmentType.BOOLEAN, True),
|
||||
("no", SegmentType.BOOLEAN, False),
|
||||
],
|
||||
)
|
||||
def test_convert_form_value_should_convert_supported_types(
|
||||
raw_value: str,
|
||||
param_type: str,
|
||||
expected: Any,
|
||||
) -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = WebhookService._convert_form_value("param", raw_value, param_type)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Unsupported type"):
|
||||
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
|
||||
|
||||
|
||||
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
|
||||
|
||||
# Assert
|
||||
assert result == {"x": 1}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
|
||||
# Arrange
|
||||
raw_params = {"optional": "x"}
|
||||
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required parameter missing"):
|
||||
WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_include_unconfigured_parameters() -> None:
|
||||
# Arrange
|
||||
raw_params = {"known": "1", "unknown": "x"}
|
||||
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
# Assert
|
||||
assert result == {"known": 1, "unknown": "x"}
|
||||
|
||||
|
||||
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required body content missing"):
|
||||
WebhookService._process_body_parameters(
|
||||
raw_body={"raw": ""},
|
||||
body_configs=[WebhookBodyParameter(name="raw", required=True)],
|
||||
content_type=ContentType.TEXT,
|
||||
)
|
||||
|
||||
|
||||
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
|
||||
# Arrange
|
||||
raw_body = {"message": "hello", "extra": "x"}
|
||||
body_configs = [
|
||||
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
|
||||
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
|
||||
|
||||
# Assert
|
||||
assert result == {"message": "hello", "extra": "x"}
|
||||
|
||||
|
||||
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
|
||||
# Arrange
|
||||
headers = {"x_api_key": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
# Assert
|
||||
assert True
|
||||
|
||||
|
||||
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
|
||||
# Arrange
|
||||
headers = {"x-other": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required header missing"):
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
|
||||
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
|
||||
# Arrange
|
||||
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
|
||||
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_http_metadata(webhook_data, node_data)
|
||||
|
||||
# Assert
|
||||
assert result["valid"] is False
|
||||
assert "Content-type mismatch" in result["error"]
|
||||
|
||||
|
||||
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
|
||||
# Arrange
|
||||
headers = {"content-type": "application/json; charset=utf-8"}
|
||||
|
||||
# Act
|
||||
result = WebhookService._extract_content_type(headers)
|
||||
|
||||
# Assert
|
||||
assert result == "application/json"
|
||||
|
||||
|
||||
def test_build_workflow_inputs_should_include_expected_keys() -> None:
|
||||
# Arrange
|
||||
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
|
||||
|
||||
# Act
|
||||
result = WebhookService.build_workflow_inputs(webhook_data)
|
||||
|
||||
# Assert
|
||||
assert result["webhook_data"] == webhook_data
|
||||
assert result["webhook_headers"] == {"h": "v"}
|
||||
assert result["webhook_query_params"] == {"q": 1}
|
||||
assert result["webhook_body"] == {"b": 2}
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
webhook_data = {"body": {"x": 1}}
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
|
||||
)
|
||||
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
trigger_async_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
|
||||
|
||||
# Act
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
# Assert
|
||||
trigger_async_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService,
|
||||
"get_or_create_end_user_by_type",
|
||||
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
|
||||
)
|
||||
quota_type = SimpleNamespace(
|
||||
TRIGGER=SimpleNamespace(
|
||||
consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(service_module, "QuotaType", quota_type)
|
||||
mark_rate_limited_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(QuotaExceededError):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
mark_rate_limited_mock.assert_called_once_with("tenant-1")
|
||||
|
||||
|
||||
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
webhook_id="webhook-1",
|
||||
)
|
||||
workflow = _workflow(id="wf-1")
|
||||
|
||||
session = MagicMock()
|
||||
_patch_session(monkeypatch, session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
|
||||
)
|
||||
logger_exception_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
|
||||
logger_exception_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(
|
||||
walk_nodes=lambda _node_type: [
|
||||
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
|
||||
]
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="maximum webhook node limit"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
|
||||
|
||||
class _WorkflowWebhookTrigger:
|
||||
app_id = "app_id"
|
||||
tenant_id = "tenant_id"
|
||||
webhook_id = "webhook_id"
|
||||
node_id = "node_id"
|
||||
|
||||
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
|
||||
self.id = None
|
||||
self.app_id = app_id
|
||||
self.tenant_id = tenant_id
|
||||
self.node_id = node_id
|
||||
self.webhook_id = webhook_id
|
||||
self.created_by = created_by
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[Any] = []
|
||||
self.deleted: list[Any] = []
|
||||
self.commit_count = 0
|
||||
self.existing_records = [SimpleNamespace(node_id="node-stale")]
|
||||
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: self.existing_records)
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def flush(self) -> None:
|
||||
for idx, obj in enumerate(self.added, start=1):
|
||||
if obj.id is None:
|
||||
obj.id = f"rec-{idx}"
|
||||
|
||||
def commit(self) -> None:
|
||||
self.commit_count += 1
|
||||
|
||||
def delete(self, obj: Any) -> None:
|
||||
self.deleted.append(obj)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.return_value = None
|
||||
|
||||
fake_session = _Session()
|
||||
|
||||
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
redis_set_mock = MagicMock()
|
||||
redis_delete_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
|
||||
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
|
||||
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
# Assert
|
||||
assert len(fake_session.added) == 1
|
||||
assert len(fake_session.deleted) == 1
|
||||
assert fake_session.commit_count == 2
|
||||
redis_set_mock.assert_called_once()
|
||||
redis_delete_mock.assert_called_once()
|
||||
lock.release.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [])
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: [])
|
||||
|
||||
def commit(self) -> None:
|
||||
return None
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = RuntimeError("release failed")
|
||||
|
||||
logger_exception_mock = MagicMock()
|
||||
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
_patch_session(monkeypatch, _Session())
|
||||
|
||||
# Act
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
# Assert
|
||||
assert logger_exception_mock.call_count == 1
|
||||
|
||||
|
||||
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
|
||||
# Arrange
|
||||
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
|
||||
|
||||
# Act
|
||||
body, status = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
# Assert
|
||||
assert status == 200
|
||||
assert "message" in body
|
||||
|
||||
|
||||
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
webhook_id = WebhookService.generate_webhook_id()
|
||||
|
||||
# Assert
|
||||
assert isinstance(webhook_id, str)
|
||||
assert len(webhook_id) == 24
|
||||
|
||||
|
||||
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
|
||||
|
||||
# Assert
|
||||
assert result == 123
|
||||
|
||||
@@ -176,3 +176,300 @@ class TestWorkflowRunService:
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
assert service._session_factory == session_factory
|
||||
|
||||
|
||||
# === Merged from test_workflow_run_service.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
|
||||
from services import workflow_run_service as service_module
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
|
||||
# Arrange
|
||||
node_repo = MagicMock()
|
||||
workflow_run_repo = MagicMock()
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
|
||||
# Assert
|
||||
return node_repo, workflow_run_repo, factory
|
||||
|
||||
|
||||
def _app_model(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
|
||||
return cast(Account, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _end_user(**kwargs: Any) -> EndUser:
|
||||
return cast(EndUser, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
session_factory = MagicMock(name="session_factory")
|
||||
sessionmaker_mock = MagicMock(return_value=session_factory)
|
||||
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))
|
||||
|
||||
# Act
|
||||
service = WorkflowRunService()
|
||||
|
||||
# Assert
|
||||
sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
|
||||
assert service._session_factory is session_factory
|
||||
|
||||
|
||||
def test___init___should_create_sessionmaker_when_engine_is_provided(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
class FakeEngine:
|
||||
pass
|
||||
|
||||
session_factory = MagicMock(name="session_factory")
|
||||
sessionmaker_mock = MagicMock(return_value=session_factory)
|
||||
monkeypatch.setattr(service_module, "Engine", FakeEngine)
|
||||
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
|
||||
engine = cast(Engine, FakeEngine())
|
||||
|
||||
# Act
|
||||
service = WorkflowRunService(session_factory=engine)
|
||||
|
||||
# Assert
|
||||
sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
|
||||
assert service._session_factory is session_factory
|
||||
|
||||
|
||||
def test___init___should_keep_provided_sessionmaker_and_create_repositories(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
node_repo, workflow_run_repo, factory = repository_factory_mocks
|
||||
session_factory = MagicMock(name="session_factory")
|
||||
|
||||
# Act
|
||||
service = WorkflowRunService(session_factory=session_factory)
|
||||
|
||||
# Assert
|
||||
assert service._session_factory is session_factory
|
||||
assert service._node_execution_service_repo is node_repo
|
||||
assert service._workflow_run_repo is workflow_run_repo
|
||||
factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory)
|
||||
factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||
|
||||
|
||||
def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
_, workflow_run_repo, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
expected = MagicMock(name="pagination")
|
||||
workflow_run_repo.get_paginated_workflow_runs.return_value = expected
|
||||
args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_workflow_runs(
|
||||
app_model=app_model,
|
||||
args=args,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is expected
|
||||
workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
limit=7,
|
||||
last_id="last-1",
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
|
||||
def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
run_with_message = SimpleNamespace(
|
||||
id="run-1",
|
||||
status="running",
|
||||
message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
|
||||
)
|
||||
run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
|
||||
pagination = SimpleNamespace(data=[run_with_message, run_without_message])
|
||||
monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
|
||||
|
||||
# Assert
|
||||
assert result is pagination
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].message_id == "msg-1"
|
||||
assert result.data[0].conversation_id == "conv-1"
|
||||
assert result.data[0].status == "running"
|
||||
assert not hasattr(result.data[1], "message_id")
|
||||
assert result.data[1].id == "run-2"
|
||||
|
||||
|
||||
def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
_, workflow_run_repo, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
expected = MagicMock(name="workflow_run")
|
||||
workflow_run_repo.get_workflow_run_by_id.return_value = expected
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_run(app_model=app_model, run_id="run-1")
|
||||
|
||||
# Assert
|
||||
assert result is expected
|
||||
workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_runs_count_should_forward_optional_filters(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
) -> None:
|
||||
# Arrange
|
||||
_, workflow_run_repo, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
app_model = _app_model(tenant_id="tenant-1", id="app-1")
|
||||
expected = {"total": 3, "succeeded": 2}
|
||||
workflow_run_repo.get_workflow_runs_count.return_value = expected
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status="succeeded",
|
||||
time_range="7d",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
status="succeeded",
|
||||
time_range="7d",
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
|
||||
app_model = _app_model(id="app-1")
|
||||
user = _account(current_tenant_id="tenant-1")
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
node_repo, _, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
|
||||
|
||||
class FakeEndUser:
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
|
||||
user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
|
||||
app_model = _app_model(id="app-1")
|
||||
expected = [SimpleNamespace(id="exec-1")]
|
||||
node_repo.get_executions_by_workflow_run.return_value = expected
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
node_repo.get_executions_by_workflow_run.assert_called_once_with(
|
||||
tenant_id="tenant-end-user",
|
||||
app_id="app-1",
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
node_repo, _, _ = repository_factory_mocks
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
|
||||
app_model = _app_model(id="app-1")
|
||||
user = _account(current_tenant_id="tenant-account")
|
||||
expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
|
||||
node_repo.get_executions_by_workflow_run.return_value = expected
|
||||
|
||||
# Act
|
||||
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
node_repo.get_executions_by_workflow_run.assert_called_once_with(
|
||||
tenant_id="tenant-account",
|
||||
app_id="app-1",
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
|
||||
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
|
||||
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
|
||||
app_model = _app_model(id="app-1")
|
||||
user = _account(current_tenant_id=None)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="tenant_id cannot be None"):
|
||||
service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
|
||||
|
||||
@@ -0,0 +1,831 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from models.model import Account, App, AppMode, AppModelConfig
|
||||
from services.workflow import workflow_converter as converter_module
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
try:
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
except ModuleNotFoundError:
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter() -> WorkflowConverter:
|
||||
return WorkflowConverter()
|
||||
|
||||
|
||||
def _app_model(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
|
||||
return cast(Account, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _app_model_config(**kwargs: Any) -> AppModelConfig:
|
||||
return cast(AppModelConfig, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _build_start_graph() -> dict[str, Any]:
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"position": None,
|
||||
"data": {"type": BuiltinNodeTypes.START, "variables": [{"variable": "name"}, {"variable": "city"}]},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
|
||||
def _build_model_config(mode: str | LLMMode) -> ModelConfigEntity:
|
||||
return ModelConfigEntity(provider="openai", model="gpt-4", mode=mode, parameters={}, stop=[])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_variables() -> list[VariableEntity]:
|
||||
return [
|
||||
VariableEntity(variable="text_input", label="text-input", type=VariableEntityType.TEXT_INPUT),
|
||||
VariableEntity(variable="paragraph", label="paragraph", type=VariableEntityType.PARAGRAPH),
|
||||
VariableEntity(variable="select", label="select", type=VariableEntityType.SELECT),
|
||||
]
|
||||
|
||||
|
||||
def test__convert_to_start_node(default_variables: list[VariableEntity]) -> None:
|
||||
result = WorkflowConverter()._convert_to_start_node(default_variables)
|
||||
|
||||
assert result["id"] == "start"
|
||||
assert result["data"]["type"] == BuiltinNodeTypes.START
|
||||
assert result["data"]["variables"][0]["type"] == "text-input"
|
||||
assert result["data"]["variables"][0]["variable"] == "text_input"
|
||||
|
||||
|
||||
def test__convert_to_http_request_node_for_chatbot(default_variables: list[VariableEntity]) -> None:
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.CHAT
|
||||
|
||||
extension = APIBasedExtension(
|
||||
tenant_id="tenant_id",
|
||||
name="api-1",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://dify.ai",
|
||||
)
|
||||
extension.id = "api_based_extension_id"
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow_converter._get_api_based_extension = MagicMock(return_value=extension)
|
||||
encrypter.decrypt_token = MagicMock(return_value="api_key")
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_variable",
|
||||
type="api",
|
||||
config={"api_based_extension_id": "api_based_extension_id"},
|
||||
),
|
||||
]
|
||||
|
||||
nodes, mapping = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app_model,
|
||||
variables=default_variables,
|
||||
external_data_variables=external_data_variables,
|
||||
)
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert nodes[0]["data"]["type"] == BuiltinNodeTypes.HTTP_REQUEST
|
||||
assert nodes[1]["data"]["type"] == BuiltinNodeTypes.CODE
|
||||
body = json.loads(nodes[0]["data"]["body"]["data"])
|
||||
assert body["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
|
||||
assert body["params"]["query"] == "{{#sys.query#}}"
|
||||
assert body["params"]["inputs"]["text_input"] == "{{#start.text_input#}}"
|
||||
assert mapping == {"external_variable": "code_1"}
|
||||
|
||||
|
||||
def test__convert_to_http_request_node_for_workflow_app(default_variables: list[VariableEntity]) -> None:
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
|
||||
extension = APIBasedExtension(
|
||||
tenant_id="tenant_id",
|
||||
name="api-1",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://dify.ai",
|
||||
)
|
||||
extension.id = "api_based_extension_id"
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow_converter._get_api_based_extension = MagicMock(return_value=extension)
|
||||
encrypter.decrypt_token = MagicMock(return_value="api_key")
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_variable",
|
||||
type="api",
|
||||
config={"api_based_extension_id": "api_based_extension_id"},
|
||||
),
|
||||
]
|
||||
|
||||
nodes, _ = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app_model,
|
||||
variables=default_variables,
|
||||
external_data_variables=external_data_variables,
|
||||
)
|
||||
|
||||
body = json.loads(nodes[0]["data"]["body"]["data"])
|
||||
assert body["params"]["query"] == ""
|
||||
|
||||
|
||||
def test__convert_to_knowledge_retrieval_node_for_chatbot() -> None:
|
||||
dataset_config = DatasetEntity(
|
||||
dataset_ids=["dataset_id_1", "dataset_id_2"],
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=5,
|
||||
score_threshold=0.8,
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
|
||||
|
||||
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
dataset_config=dataset_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
assert node is not None
|
||||
assert node["data"]["query_variable_selector"] == ["sys", "query"]
|
||||
assert node["data"]["multiple_retrieval_config"]["top_k"] == 5
|
||||
|
||||
|
||||
def test__convert_to_knowledge_retrieval_node_for_workflow_app() -> None:
|
||||
dataset_config = DatasetEntity(
|
||||
dataset_ids=["dataset_id_1", "dataset_id_2"],
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable="query",
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=5,
|
||||
score_threshold=0.8,
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
|
||||
|
||||
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=AppMode.WORKFLOW,
|
||||
dataset_config=dataset_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
assert node is not None
|
||||
assert node["data"]["query_variable_selector"] == ["start", "query"]
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables: list[VariableEntity]) -> None:
|
||||
workflow_converter = WorkflowConverter()
|
||||
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="You are a helper for {{text_input}} and {{paragraph}}",
|
||||
)
|
||||
|
||||
node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
model_config=model_config,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert node["data"]["type"] == BuiltinNodeTypes.LLM
|
||||
assert node["data"]["memory"] is not None
|
||||
assert node["data"]["prompt_template"][0]["role"] == "user"
|
||||
assert "{{#start.text_input#}}" in node["data"]["prompt_template"][0]["text"]
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_simple_chat_model_with_empty_template(
|
||||
default_variables: list[VariableEntity],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
workflow_converter = WorkflowConverter()
|
||||
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="ignored",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter_module.SimplePromptTransform,
|
||||
"get_prompt_template",
|
||||
lambda self, **kwargs: {"prompt_template": PromptTemplateParser(""), "prompt_rules": {}},
|
||||
)
|
||||
|
||||
node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
model_config=model_config,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert node["data"]["prompt_template"] == []
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables: list[VariableEntity]) -> None:
|
||||
workflow_converter = WorkflowConverter()
|
||||
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
|
||||
messages=[AdvancedChatMessageEntity(text="Hello {{text_input}}", role=PromptMessageRole.USER)]
|
||||
),
|
||||
)
|
||||
|
||||
node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
model_config=model_config,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert isinstance(node["data"]["prompt_template"], list)
|
||||
assert node["data"]["prompt_template"][0]["role"] == PromptMessageRole.USER.value
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_advanced_chat_model_without_template(
|
||||
default_variables: list[VariableEntity],
|
||||
) -> None:
|
||||
workflow_converter = WorkflowConverter()
|
||||
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode=LLMMode.CHAT.value, parameters={}, stop=[])
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=None,
|
||||
)
|
||||
|
||||
node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=AppMode.WORKFLOW,
|
||||
model_config=model_config,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert node["data"]["prompt_template"] == []
|
||||
assert node["data"]["memory"] is None
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables: list[VariableEntity]) -> None:
|
||||
workflow_converter = WorkflowConverter()
|
||||
graph = {"nodes": [workflow_converter._convert_to_start_node(default_variables)], "edges": []}
|
||||
model_config = ModelConfigEntity(
|
||||
provider="openai",
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
mode=LLMMode.COMPLETION.value,
|
||||
parameters={},
|
||||
stop=[],
|
||||
)
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
|
||||
prompt="Hello {{text_input}} and {{#query#}}",
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
|
||||
),
|
||||
)
|
||||
|
||||
node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.COMPLETION,
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
model_config=model_config,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert node["data"]["prompt_template"]["text"].find("{{#sys.query#}}") != -1
|
||||
assert node["data"]["memory"]["role_prefix"]["user"] == "Human"
|
||||
|
||||
|
||||
def test__convert_to_end_node() -> None:
|
||||
node = WorkflowConverter()._convert_to_end_node()
|
||||
assert node["id"] == "end"
|
||||
assert node["data"]["type"] == BuiltinNodeTypes.END
|
||||
|
||||
|
||||
def test__convert_to_answer_node() -> None:
|
||||
node = WorkflowConverter()._convert_to_answer_node()
|
||||
assert node["id"] == "answer"
|
||||
assert node["data"]["type"] == BuiltinNodeTypes.ANSWER
|
||||
|
||||
|
||||
def test_convert_to_workflow_should_raise_when_app_model_config_is_missing(converter: WorkflowConverter) -> None:
|
||||
app_model = _app_model(app_model_config=None)
|
||||
|
||||
with pytest.raises(ValueError, match="App model config is required"):
|
||||
converter.convert_to_workflow(
|
||||
app_model=app_model,
|
||||
account=_account(id="account-1"),
|
||||
name="new-app",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#fff",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("source_mode", "expected_mode"),
|
||||
[
|
||||
(AppMode.CHAT, AppMode.ADVANCED_CHAT),
|
||||
(AppMode.COMPLETION, AppMode.WORKFLOW),
|
||||
],
|
||||
)
|
||||
def test_convert_to_workflow_should_create_new_app_with_fallback_fields(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
source_mode: AppMode,
|
||||
expected_mode: AppMode,
|
||||
) -> None:
|
||||
class FakeApp:
|
||||
def __init__(self) -> None:
|
||||
self.id = "new-app-id"
|
||||
|
||||
workflow = SimpleNamespace(app_id=None)
|
||||
monkeypatch.setattr(converter, "convert_app_model_config_to_workflow", MagicMock(return_value=workflow))
|
||||
monkeypatch.setattr(converter_module, "App", FakeApp)
|
||||
|
||||
db_session = SimpleNamespace(add=MagicMock(), flush=MagicMock(), commit=MagicMock())
|
||||
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
|
||||
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr(converter_module.app_was_created, "send", send_mock)
|
||||
|
||||
account = _account(id="account-1")
|
||||
app_model = _app_model(
|
||||
tenant_id="tenant-1",
|
||||
name="Source App",
|
||||
mode=source_mode,
|
||||
icon_type="emoji",
|
||||
icon="sparkles",
|
||||
icon_background="#123456",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=10,
|
||||
api_rph=100,
|
||||
is_public=False,
|
||||
app_model_config=_app_model_config(id="config-1"),
|
||||
)
|
||||
|
||||
new_app = converter.convert_to_workflow(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
name="",
|
||||
icon_type="",
|
||||
icon="",
|
||||
icon_background="",
|
||||
)
|
||||
|
||||
assert new_app.name == "Source App(workflow)"
|
||||
assert new_app.mode == expected_mode
|
||||
assert new_app.icon_type == "emoji"
|
||||
assert new_app.icon == "sparkles"
|
||||
assert new_app.icon_background == "#123456"
|
||||
assert new_app.created_by == "account-1"
|
||||
assert workflow.app_id == "new-app-id"
|
||||
db_session.add.assert_called_once()
|
||||
db_session.flush.assert_called_once()
|
||||
db_session.commit.assert_called_once()
|
||||
send_mock.assert_called_once_with(new_app, account=account)
|
||||
|
||||
|
||||
def test_convert_app_model_config_to_workflow_should_build_advanced_chat_graph_and_features(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT)
|
||||
app_config = SimpleNamespace(
|
||||
variables=[SimpleNamespace(variable="name")],
|
||||
external_data_variables=[SimpleNamespace(variable="ext")],
|
||||
dataset=SimpleNamespace(id="dataset"),
|
||||
model=SimpleNamespace(),
|
||||
prompt_template=SimpleNamespace(),
|
||||
additional_features=SimpleNamespace(file_upload=SimpleNamespace()),
|
||||
app_model_config_dict={
|
||||
"opening_statement": "hello",
|
||||
"suggested_questions": ["q1"],
|
||||
"suggested_questions_after_answer": True,
|
||||
"speech_to_text": True,
|
||||
"text_to_speech": {"enabled": True},
|
||||
"file_upload": {"enabled": True},
|
||||
"sensitive_word_avoidance": {"enabled": True},
|
||||
"retriever_resource": {"enabled": True},
|
||||
},
|
||||
)
|
||||
|
||||
class FakeWorkflow:
|
||||
VERSION_DRAFT = "draft"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.ADVANCED_CHAT))
|
||||
monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config))
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_start_node",
|
||||
MagicMock(
|
||||
return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_http_request_node",
|
||||
MagicMock(
|
||||
return_value=(
|
||||
[{"id": "http", "position": None, "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}],
|
||||
{"ext": "code_1"},
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_knowledge_retrieval_node",
|
||||
MagicMock(
|
||||
return_value={"id": "knowledge", "position": None, "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_llm_node",
|
||||
MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_answer_node",
|
||||
MagicMock(return_value={"id": "answer", "position": None, "data": {"type": BuiltinNodeTypes.ANSWER}}),
|
||||
)
|
||||
monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow)
|
||||
|
||||
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock())
|
||||
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
|
||||
|
||||
workflow = converter.convert_app_model_config_to_workflow(
|
||||
app_model=app_model,
|
||||
app_model_config=_app_model_config(id="cfg"),
|
||||
account_id="account-1",
|
||||
)
|
||||
|
||||
graph = json.loads(workflow.graph)
|
||||
node_ids = [node["id"] for node in graph["nodes"]]
|
||||
assert node_ids == ["start", "http", "knowledge", "llm", "answer"]
|
||||
|
||||
features = json.loads(workflow.features)
|
||||
assert "opening_statement" in features
|
||||
assert "retriever_resource" in features
|
||||
db_session.add.assert_called_once()
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_convert_app_model_config_to_workflow_should_build_workflow_mode_with_end_node(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.COMPLETION)
|
||||
app_config = SimpleNamespace(
|
||||
variables=[SimpleNamespace(variable="name")],
|
||||
external_data_variables=[],
|
||||
dataset=SimpleNamespace(id="dataset"),
|
||||
model=SimpleNamespace(),
|
||||
prompt_template=SimpleNamespace(),
|
||||
additional_features=None,
|
||||
app_model_config_dict={
|
||||
"text_to_speech": {"enabled": False},
|
||||
"file_upload": {"enabled": False},
|
||||
"sensitive_word_avoidance": {"enabled": False},
|
||||
},
|
||||
)
|
||||
|
||||
class FakeWorkflow:
|
||||
VERSION_DRAFT = "draft"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
monkeypatch.setattr(converter, "_get_new_app_mode", MagicMock(return_value=AppMode.WORKFLOW))
|
||||
monkeypatch.setattr(converter, "_convert_to_app_config", MagicMock(return_value=app_config))
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_start_node",
|
||||
MagicMock(
|
||||
return_value={"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START, "variables": []}}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(converter, "_convert_to_knowledge_retrieval_node", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_llm_node",
|
||||
MagicMock(return_value={"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter,
|
||||
"_convert_to_end_node",
|
||||
MagicMock(return_value={"id": "end", "position": None, "data": {"type": BuiltinNodeTypes.END}}),
|
||||
)
|
||||
monkeypatch.setattr(converter_module, "Workflow", FakeWorkflow)
|
||||
|
||||
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock())
|
||||
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
|
||||
|
||||
workflow = converter.convert_app_model_config_to_workflow(
|
||||
app_model=app_model,
|
||||
app_model_config=_app_model_config(id="cfg"),
|
||||
account_id="account-1",
|
||||
)
|
||||
|
||||
graph = json.loads(workflow.graph)
|
||||
node_ids = [node["id"] for node in graph["nodes"]]
|
||||
assert node_ids == ["start", "llm", "end"]
|
||||
|
||||
features = json.loads(workflow.features)
|
||||
assert set(features.keys()) == {"text_to_speech", "file_upload", "sensitive_word_avoidance"}
|
||||
|
||||
|
||||
def test_convert_to_app_config_should_route_to_correct_manager(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
agent_result = SimpleNamespace(kind="agent")
|
||||
chat_result = SimpleNamespace(kind="chat")
|
||||
completion_result = SimpleNamespace(kind="completion")
|
||||
monkeypatch.setattr(
|
||||
converter_module.AgentChatAppConfigManager, "get_app_config", MagicMock(return_value=agent_result)
|
||||
)
|
||||
monkeypatch.setattr(converter_module.ChatAppConfigManager, "get_app_config", MagicMock(return_value=chat_result))
|
||||
monkeypatch.setattr(
|
||||
converter_module.CompletionAppConfigManager,
|
||||
"get_app_config",
|
||||
MagicMock(return_value=completion_result),
|
||||
)
|
||||
|
||||
from_agent_mode = converter._convert_to_app_config(
|
||||
app_model=_app_model(mode=AppMode.AGENT_CHAT, is_agent=False),
|
||||
app_model_config=_app_model_config(id="cfg-1"),
|
||||
)
|
||||
from_agent_flag = converter._convert_to_app_config(
|
||||
app_model=_app_model(mode=AppMode.CHAT, is_agent=True),
|
||||
app_model_config=_app_model_config(id="cfg-2"),
|
||||
)
|
||||
from_chat_mode = converter._convert_to_app_config(
|
||||
app_model=_app_model(mode=AppMode.CHAT, is_agent=False),
|
||||
app_model_config=_app_model_config(id="cfg-3"),
|
||||
)
|
||||
from_completion_mode = converter._convert_to_app_config(
|
||||
app_model=_app_model(mode=AppMode.COMPLETION, is_agent=False),
|
||||
app_model_config=_app_model_config(id="cfg-4"),
|
||||
)
|
||||
|
||||
assert from_agent_mode is agent_result
|
||||
assert from_agent_flag is agent_result
|
||||
assert from_chat_mode is chat_result
|
||||
assert from_completion_mode is completion_result
|
||||
|
||||
|
||||
def test_convert_to_app_config_should_raise_for_invalid_app_mode(converter: WorkflowConverter) -> None:
|
||||
app_model = _app_model(mode=AppMode.WORKFLOW, is_agent=False)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid app mode"):
|
||||
converter._convert_to_app_config(app_model=app_model, app_model_config=_app_model_config(id="cfg"))
|
||||
|
||||
|
||||
def test_convert_to_http_request_node_should_skip_non_api_and_missing_extension_id(
|
||||
converter: WorkflowConverter,
|
||||
) -> None:
|
||||
app_model = _app_model(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT)
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(variable="skip_type", type="dataset", config={"api_based_extension_id": "x"}),
|
||||
ExternalDataVariableEntity(variable="skip_config", type="api", config={}),
|
||||
]
|
||||
|
||||
nodes, mapping = converter._convert_to_http_request_node(
|
||||
app_model=app_model,
|
||||
variables=[],
|
||||
external_data_variables=external_data_variables,
|
||||
)
|
||||
|
||||
assert nodes == []
|
||||
assert mapping == {}
|
||||
|
||||
|
||||
def test_convert_to_knowledge_retrieval_node_should_return_none_for_workflow_without_query_variable(
|
||||
converter: WorkflowConverter,
|
||||
) -> None:
|
||||
dataset_config = DatasetEntity(
|
||||
dataset_ids=["ds-1"],
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=None,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
),
|
||||
)
|
||||
model_config = _build_model_config(mode=LLMMode.CHAT)
|
||||
|
||||
node = converter._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=AppMode.WORKFLOW,
|
||||
dataset_config=dataset_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
assert node is None
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_raise_when_simple_chat_template_missing(
|
||||
converter: WorkflowConverter,
|
||||
) -> None:
|
||||
graph = _build_start_graph()
|
||||
model_config = _build_model_config(mode=LLMMode.CHAT)
|
||||
prompt_template = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE)
|
||||
|
||||
with pytest.raises(ValueError, match="Simple prompt template is required"):
|
||||
converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
graph=graph,
|
||||
model_config=model_config,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_raise_when_prompt_template_parser_type_is_invalid_for_chat(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
graph = _build_start_graph()
|
||||
model_config = _build_model_config(mode=LLMMode.CHAT)
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="Hello {{name}}",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter_module.SimplePromptTransform,
|
||||
"get_prompt_template",
|
||||
lambda self, **kwargs: {"prompt_template": "invalid"},
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match="Expected PromptTemplateParser"):
|
||||
converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
graph=graph,
|
||||
model_config=model_config,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_raise_when_simple_completion_template_missing(
|
||||
converter: WorkflowConverter,
|
||||
) -> None:
|
||||
graph = _build_start_graph()
|
||||
model_config = _build_model_config(mode=LLMMode.COMPLETION)
|
||||
prompt_template = PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.SIMPLE)
|
||||
|
||||
with pytest.raises(ValueError, match="Simple prompt template is required"):
|
||||
converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.COMPLETION,
|
||||
new_app_mode=AppMode.WORKFLOW,
|
||||
graph=graph,
|
||||
model_config=model_config,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_raise_when_completion_prompt_rules_type_is_invalid(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
graph = _build_start_graph()
|
||||
model_config = _build_model_config(mode=LLMMode.COMPLETION)
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="Hello {{name}}",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
converter_module.SimplePromptTransform,
|
||||
"get_prompt_template",
|
||||
lambda self, **kwargs: {"prompt_template": PromptTemplateParser("Hello {{name}}"), "prompt_rules": "invalid"},
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match="Expected dict for prompt_rules"):
|
||||
converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.COMPLETION,
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
graph=graph,
|
||||
model_config=model_config,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
|
||||
def test_convert_to_llm_node_should_use_empty_text_for_advanced_completion_without_template(
|
||||
converter: WorkflowConverter,
|
||||
) -> None:
|
||||
graph = _build_start_graph()
|
||||
model_config = _build_model_config(mode=LLMMode.COMPLETION)
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=None,
|
||||
)
|
||||
|
||||
llm_node = converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.COMPLETION,
|
||||
new_app_mode=AppMode.WORKFLOW,
|
||||
graph=graph,
|
||||
model_config=model_config,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["prompt_template"]["text"] == ""
|
||||
assert llm_node["data"]["memory"] is None
|
||||
|
||||
|
||||
def test_replace_template_variables_should_replace_start_and_external_references(converter: WorkflowConverter) -> None:
|
||||
template = "Hello {{name}} from {{city}} with {{weather}}"
|
||||
variables = [{"variable": "name"}, {"variable": "city"}]
|
||||
external_mapping = {"weather": "code_1"}
|
||||
|
||||
result = converter._replace_template_variables(template, variables, external_mapping)
|
||||
|
||||
assert result == "Hello {{#start.name#}} from {{#start.city#}} with {{#code_1.result#}}"
|
||||
|
||||
|
||||
def test_graph_helpers_should_create_edges_append_nodes_and_choose_mode(converter: WorkflowConverter) -> None:
|
||||
graph = {"nodes": [{"id": "start", "position": None, "data": {"type": BuiltinNodeTypes.START}}], "edges": []}
|
||||
node = {"id": "llm", "position": None, "data": {"type": BuiltinNodeTypes.LLM}}
|
||||
|
||||
edge = converter._create_edge("start", "llm")
|
||||
updated_graph = converter._append_node(graph, node)
|
||||
workflow_mode = converter._get_new_app_mode(_app_model(mode=AppMode.COMPLETION))
|
||||
advanced_chat_mode = converter._get_new_app_mode(_app_model(mode=AppMode.CHAT))
|
||||
|
||||
assert edge == {"id": "start-llm", "source": "start", "target": "llm"}
|
||||
assert updated_graph["nodes"][-1]["id"] == "llm"
|
||||
assert updated_graph["edges"][-1]["source"] == "start"
|
||||
assert workflow_mode == AppMode.WORKFLOW
|
||||
assert advanced_chat_mode == AppMode.ADVANCED_CHAT
|
||||
|
||||
|
||||
def test_get_api_based_extension_should_raise_when_extension_not_found(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
db_session = SimpleNamespace(scalar=MagicMock(return_value=None))
|
||||
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
|
||||
|
||||
with pytest.raises(ValueError, match="API Based Extension not found"):
|
||||
converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1")
|
||||
db_session.scalar.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_based_extension_should_return_entity_when_found(
|
||||
converter: WorkflowConverter,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
extension = SimpleNamespace(id="ext-1")
|
||||
db_session = SimpleNamespace(scalar=MagicMock(return_value=extension))
|
||||
monkeypatch.setattr(converter_module, "db", SimpleNamespace(session=db_session))
|
||||
|
||||
result = converter._get_api_based_extension(tenant_id="tenant-1", api_based_extension_id="ext-1")
|
||||
|
||||
assert result is extension
|
||||
db_session.scalar.assert_called_once()
|
||||
@@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from itertools import cycle
|
||||
from threading import Event
|
||||
|
||||
import pytest
|
||||
@@ -224,3 +223,577 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -
|
||||
buffer_state.task_id_ready.set()
|
||||
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
|
||||
assert task_id == expected
|
||||
|
||||
|
||||
# === Merged from test_workflow_event_snapshot_service_additional.py ===
|
||||
|
||||
|
||||
import json
|
||||
import queue
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from threading import Event
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services import workflow_event_snapshot_service as service_module
|
||||
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
|
||||
|
||||
|
||||
def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
|
||||
return WorkflowRun(
|
||||
id="run-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="v1",
|
||||
graph=None,
|
||||
inputs=json.dumps({"query": "hello"}),
|
||||
status=status,
|
||||
outputs=json.dumps({}),
|
||||
error=None,
|
||||
elapsed_time=1.2,
|
||||
total_tokens=5,
|
||||
total_steps=2,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
|
||||
def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=task_id,
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
call_depth=0,
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
runtime_state.outputs = {"answer": "ok"}
|
||||
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
|
||||
return WorkflowResumptionContext(
|
||||
generate_entity=wrapper,
|
||||
serialized_graph_runtime_state=runtime_state.dumps(),
|
||||
)
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionMaker:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __call__(self) -> _SessionContext:
|
||||
return _SessionContext(self._session)
|
||||
|
||||
|
||||
class _SubscriptionContext:
|
||||
def __init__(self, subscription: Any) -> None:
|
||||
self._subscription = subscription
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._subscription
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _Topic:
|
||||
def __init__(self, subscription: Any) -> None:
|
||||
self._subscription = subscription
|
||||
|
||||
def subscribe(self) -> _SubscriptionContext:
|
||||
return _SubscriptionContext(self._subscription)
|
||||
|
||||
|
||||
class _StaticSubscription:
|
||||
def receive(self, timeout: int = 1) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _PauseEntity(WorkflowPauseEntity):
|
||||
state: bytes
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return "pause-1"
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str:
|
||||
return "run-1"
|
||||
|
||||
@property
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return datetime(2024, 1, 1, tzinfo=UTC)
|
||||
|
||||
def get_state(self) -> bytes:
|
||||
return self.state
|
||||
|
||||
def get_pause_reasons(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def test_get_message_context_should_return_none_when_no_message() -> None:
|
||||
# Arrange
|
||||
session = SimpleNamespace(scalar=MagicMock(return_value=None))
|
||||
session_maker = _SessionMaker(session)
|
||||
|
||||
# Act
|
||||
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None:
|
||||
# Arrange
|
||||
message = SimpleNamespace(
|
||||
id="msg-1",
|
||||
conversation_id="conv-1",
|
||||
created_at=None,
|
||||
answer="answer",
|
||||
)
|
||||
session = SimpleNamespace(scalar=MagicMock(return_value=message))
|
||||
session_maker = _SessionMaker(session)
|
||||
|
||||
# Act
|
||||
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.created_at == 0
|
||||
assert result.message_id == "msg-1"
|
||||
assert result.conversation_id == "conv-1"
|
||||
assert result.answer == "answer"
|
||||
|
||||
|
||||
def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = service_module._load_resumption_context(None)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None:
|
||||
# Arrange
|
||||
pause_entity = _PauseEntity(state=b"not-a-valid-state")
|
||||
|
||||
# Act
|
||||
result = service_module._load_resumption_context(pause_entity)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_load_resumption_context_should_parse_valid_state_into_context() -> None:
|
||||
# Arrange
|
||||
context = _build_resumption_context_additional(task_id="task-ctx")
|
||||
pause_entity = _PauseEntity(state=context.dumps().encode())
|
||||
|
||||
# Act
|
||||
result = service_module._load_resumption_context(pause_entity)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.get_generate_entity().task_id == "task-ctx"
|
||||
|
||||
|
||||
def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = service_module._resolve_task_id(
|
||||
resumption_context=None,
|
||||
buffer_state=None,
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == "run-1"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("payload", "expected"),
|
||||
[
|
||||
(b'{"event":"node_started"}', {"event": "node_started"}),
|
||||
(b"invalid-json", None),
|
||||
(b"[]", None),
|
||||
],
|
||||
)
|
||||
def test_parse_event_message_should_parse_only_json_object(
|
||||
payload: bytes,
|
||||
expected: dict[str, Any] | None,
|
||||
) -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = service_module._parse_event_message(payload)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None:
|
||||
# Arrange
|
||||
finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
|
||||
paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}
|
||||
|
||||
# Act
|
||||
is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
|
||||
paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
|
||||
paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)
|
||||
|
||||
# Assert
|
||||
assert is_finished is True
|
||||
assert paused_without_flag is False
|
||||
assert paused_with_flag is True
|
||||
assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False
|
||||
|
||||
|
||||
def test_apply_message_context_should_update_payload_when_context_exists() -> None:
|
||||
# Arrange
|
||||
payload: dict[str, Any] = {"event": "workflow_started"}
|
||||
context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)
|
||||
|
||||
# Act
|
||||
service_module._apply_message_context(payload, context)
|
||||
|
||||
# Assert
|
||||
assert payload["conversation_id"] == "conv-1"
|
||||
assert payload["message_id"] == "msg-1"
|
||||
assert payload["created_at"] == 1700000000
|
||||
|
||||
|
||||
def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None:
|
||||
# Arrange
|
||||
class Subscription:
|
||||
def __init__(self) -> None:
|
||||
self._calls = 0
|
||||
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
return b'{"event":"node_started","task_id":"task-1"}'
|
||||
return None
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
# Act
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
ready = buffer_state.task_id_ready.wait(timeout=1)
|
||||
event = buffer_state.queue.get(timeout=1)
|
||||
buffer_state.stop_event.set()
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
# Assert
|
||||
assert ready is True
|
||||
assert finished is True
|
||||
assert buffer_state.task_id_hint == "task-1"
|
||||
assert event["event"] == "node_started"
|
||||
|
||||
|
||||
def test_start_buffering_should_drop_old_event_when_queue_is_full(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
class QueueWithSingleFull:
|
||||
def __init__(self) -> None:
|
||||
self._first_put = True
|
||||
self.items: list[dict[str, Any]] = [{"event": "old"}]
|
||||
|
||||
def put_nowait(self, item: dict[str, Any]) -> None:
|
||||
if self._first_put:
|
||||
self._first_put = False
|
||||
raise queue.Full
|
||||
self.items.append(item)
|
||||
|
||||
def get_nowait(self) -> dict[str, Any]:
|
||||
if not self.items:
|
||||
raise queue.Empty
|
||||
return self.items.pop(0)
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.items) == 0
|
||||
|
||||
fake_queue = QueueWithSingleFull()
|
||||
monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)
|
||||
|
||||
class Subscription:
|
||||
def __init__(self) -> None:
|
||||
self._calls = 0
|
||||
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
return b'{"event":"node_started","task_id":"task-2"}'
|
||||
return None
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
# Act
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
ready = buffer_state.task_id_ready.wait(timeout=1)
|
||||
buffer_state.stop_event.set()
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
# Assert
|
||||
assert ready is True
|
||||
assert finished is True
|
||||
assert fake_queue.items[-1]["task_id"] == "task-2"
|
||||
|
||||
|
||||
def test_start_buffering_should_set_done_event_when_subscription_raises() -> None:
|
||||
# Arrange
|
||||
class Subscription:
|
||||
def receive(self, timeout: int = 1) -> bytes | None:
|
||||
raise RuntimeError("subscription failure")
|
||||
|
||||
subscription = Subscription()
|
||||
|
||||
# Act
|
||||
buffer_state = service_module._start_buffering(subscription)
|
||||
finished = buffer_state.done_event.wait(timeout=1)
|
||||
|
||||
# Assert
|
||||
assert finished is True
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"_get_message_context",
|
||||
MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
monkeypatch.setattr(
|
||||
service_module,
|
||||
"_build_snapshot_events",
|
||||
MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
|
||||
)
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events[0] == StreamEvent.PING.value
|
||||
finished_event = cast(Mapping[str, Any], events[1])
|
||||
assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
|
||||
called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
|
||||
assert called_kwargs["workflow_run_id"] == "run-1"
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
|
||||
class AlwaysEmptyQueue:
|
||||
def empty(self) -> bool:
|
||||
return False
|
||||
|
||||
def get(self, timeout: int = 1) -> None:
|
||||
raise queue.Empty
|
||||
|
||||
buffer_state = BufferState(
|
||||
queue=AlwaysEmptyQueue(), # type: ignore[arg-type]
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
time_values = cycle([0.0, 6.0, 21.0, 26.0])
|
||||
monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
idle_timeout=20.0,
|
||||
ping_interval=5.0,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
buffer_state.done_event.set()
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events == [StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
|
||||
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
# Act
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert events[0] == StreamEvent.PING.value
|
||||
assert snapshot_builder.call_args.kwargs["pause_entity"] is None
|
||||
|
||||
@@ -10,6 +10,8 @@ This module tests the document indexing task functionality including:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextlib import nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -1113,13 +1115,17 @@ class TestAdvancedScenarios:
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task)
|
||||
|
||||
# Assert
|
||||
# Verify delete was called to clean up task key
|
||||
mock_redis.delete.assert_called_once()
|
||||
expected_task_key = f"tenant_document_indexing_task:{tenant_id}"
|
||||
|
||||
# Verify the correct key was deleted (contains tenant_id and "document_indexing")
|
||||
delete_call_args = mock_redis.delete.call_args[0][0]
|
||||
assert tenant_id in delete_call_args
|
||||
assert "document_indexing" in delete_call_args
|
||||
# Verify the task key for this tenant was deleted (do not assert call count; fixtures may be shared).
|
||||
mock_redis.delete.assert_any_call(expected_task_key)
|
||||
|
||||
deleted_keys = [delete_call.args[0] for delete_call in mock_redis.delete.call_args_list if delete_call.args]
|
||||
assert expected_task_key in deleted_keys
|
||||
|
||||
deleted_task_key = next(key for key in deleted_keys if key == expected_task_key)
|
||||
assert tenant_id in deleted_task_key
|
||||
assert "document_indexing" in deleted_task_key
|
||||
|
||||
def test_billing_disabled_skips_limit_checks(
|
||||
self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner, mock_feature_service
|
||||
@@ -1510,3 +1516,475 @@ class TestRobustness:
|
||||
|
||||
# Verify the exception message
|
||||
assert "Feature service" in str(exc_info.value) or isinstance(exc_info.value, Exception)
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: MagicMock) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> MagicMock:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[override]
|
||||
return None
|
||||
|
||||
|
||||
class TestDocumentIndexingTaskSummaryFlow:
|
||||
"""Additional coverage for summary and tenant queue branches."""
|
||||
|
||||
def test_should_return_when_dataset_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test early return when dataset does not exist."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = None
|
||||
session.query.side_effect = lambda model: dataset_query
|
||||
|
||||
create_session_mock = MagicMock(return_value=_SessionContext(session))
|
||||
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
|
||||
features_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.FeatureService.get_features", features_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
features_mock.assert_not_called()
|
||||
|
||||
def test_should_mark_documents_error_when_batch_upload_limit_exceeded(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Test batch upload limit triggers error handling."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None)
|
||||
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.first.return_value = document
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(return_value=_SessionContext(session)),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(
|
||||
enabled=True,
|
||||
subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL),
|
||||
),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.dify_config.BATCH_UPLOAD_LIMIT", "1")
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1", "doc-2"])
|
||||
|
||||
# Assert
|
||||
assert document.indexing_status == "error"
|
||||
assert "batch upload limit" in document.error
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_should_queue_summary_generation_for_completed_documents(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test summary generation is queued for eligible documents."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
indexing_technique="high_quality",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
|
||||
doc_eligible = SimpleNamespace(
|
||||
id="doc-1",
|
||||
indexing_status="completed",
|
||||
doc_form="text",
|
||||
need_summary=True,
|
||||
)
|
||||
doc_skip_form = SimpleNamespace(
|
||||
id="doc-2",
|
||||
indexing_status="completed",
|
||||
doc_form="qa_model",
|
||||
need_summary=True,
|
||||
)
|
||||
doc_skip_status = SimpleNamespace(
|
||||
id="doc-3",
|
||||
indexing_status="processing",
|
||||
doc_form="text",
|
||||
need_summary=True,
|
||||
)
|
||||
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
phase1_docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2"), SimpleNamespace(id="doc-3")]
|
||||
phase1_document_query = MagicMock()
|
||||
phase1_document_query.where.return_value = phase1_document_query
|
||||
phase1_document_query.all.return_value = phase1_docs
|
||||
|
||||
summary_document_query = MagicMock()
|
||||
summary_document_query.where.return_value = summary_document_query
|
||||
summary_document_query.all.return_value = [doc_eligible, doc_skip_form, doc_skip_status]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: phase1_document_query
|
||||
session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query
|
||||
|
||||
create_session_mock = MagicMock(
|
||||
side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
|
||||
indexing_runner = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner))
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1", "doc-2", "doc-3"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_called_once_with("dataset-1", "doc-1", None)
|
||||
|
||||
def test_should_continue_when_summary_queue_fails(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test summary queueing errors are swallowed."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
indexing_technique="high_quality",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
|
||||
doc_eligible = SimpleNamespace(
|
||||
id="doc-1",
|
||||
indexing_status="completed",
|
||||
doc_form="text",
|
||||
need_summary=True,
|
||||
)
|
||||
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
phase1_query = MagicMock()
|
||||
phase1_query.where.return_value = phase1_query
|
||||
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value = summary_query
|
||||
summary_query.all.return_value = [doc_eligible]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: phase1_query
|
||||
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
|
||||
indexing_runner = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=indexing_runner))
|
||||
delay_mock = MagicMock(side_effect=Exception("boom"))
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_called_once_with("dataset-1", "doc-1", None)
|
||||
|
||||
def test_should_return_when_dataset_missing_after_indexing(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test early return when dataset is missing after indexing."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.side_effect = [dataset, None]
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
session3.query.side_effect = lambda model: dataset_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
session3.query.assert_called()
|
||||
|
||||
def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test summary generation skipped when indexing_technique is not high_quality."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
indexing_technique="economy",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
session3.query.side_effect = lambda model: dataset_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
|
||||
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_not_called()
|
||||
|
||||
def test_should_skip_summary_generation_when_indexing_paused(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test summary generation is skipped when indexing is paused."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
|
||||
create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)])
|
||||
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
|
||||
runner = MagicMock()
|
||||
runner.run.side_effect = DocumentIsPausedError("paused")
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner))
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_not_called()
|
||||
|
||||
def test_should_handle_indexing_runner_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test generic indexing runner exception is handled."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
|
||||
runner = MagicMock()
|
||||
runner.run.side_effect = RuntimeError("boom")
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=runner))
|
||||
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_not_called()
|
||||
|
||||
def test_should_log_missing_document_entry_in_summary_list(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test falsey document entries are handled in summary iteration."""
|
||||
|
||||
# Arrange
|
||||
class _FalseyDocument:
|
||||
def __init__(self, doc_id: str) -> None:
|
||||
self.id = doc_id
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
tenant_id="tenant-1",
|
||||
indexing_technique="high_quality",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
phase1_query = MagicMock()
|
||||
phase1_query.where.return_value = phase1_query
|
||||
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value = summary_query
|
||||
summary_query.all.return_value = [_FalseyDocument("missing-doc")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: phase1_query
|
||||
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]),
|
||||
)
|
||||
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
vector_space=SimpleNamespace(limit=0, size=0),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.FeatureService.get_features", MagicMock(return_value=features)
|
||||
)
|
||||
monkeypatch.setattr("tasks.document_indexing_task.IndexingRunner", MagicMock(return_value=MagicMock()))
|
||||
|
||||
delay_mock = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task.generate_summary_index_task.delay", delay_mock)
|
||||
|
||||
# Act
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
delay_mock.assert_not_called()
|
||||
|
||||
def test_normal_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test normal indexing task delegates to tenant queue handler."""
|
||||
# Arrange
|
||||
handler = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler)
|
||||
|
||||
# Act
|
||||
normal_document_indexing_task("tenant-1", "dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], normal_document_indexing_task)
|
||||
|
||||
def test_priority_document_indexing_task_should_delegate(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test priority indexing task delegates to tenant queue handler."""
|
||||
# Arrange
|
||||
handler = MagicMock()
|
||||
monkeypatch.setattr("tasks.document_indexing_task._document_indexing_with_tenant_queue", handler)
|
||||
|
||||
# Act
|
||||
priority_document_indexing_task("tenant-1", "dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
handler.assert_called_once_with("tenant-1", "dataset-1", ["doc-1"], priority_document_indexing_task)
|
||||
|
||||
Reference in New Issue
Block a user