diff --git a/api/tests/unit_tests/core/helper/code_executor/jinja2/test_jinja2_formatter.py b/api/tests/unit_tests/core/helper/code_executor/jinja2/test_jinja2_formatter.py new file mode 100644 index 00000000000..60002a757d5 --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/jinja2/test_jinja2_formatter.py @@ -0,0 +1,24 @@ +from pytest_mock import MockerFixture + +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter + + +def test_format_returns_result_value_as_string(mocker: MockerFixture) -> None: + execute_mock = mocker.patch( + "core.helper.code_executor.jinja2.jinja2_formatter.CodeExecutor.execute_workflow_code_template", + return_value={"result": 123}, + ) + + formatted = Jinja2Formatter.format("Hello {{ name }}", {"name": "Dify"}) + + assert formatted == "123" + execute_mock.assert_called_once() + + +def test_format_returns_empty_string_when_result_missing(mocker: MockerFixture) -> None: + mocker.patch( + "core.helper.code_executor.jinja2.jinja2_formatter.CodeExecutor.execute_workflow_code_template", + return_value={}, + ) + + assert Jinja2Formatter.format("Hello", {"name": "Dify"}) == "" diff --git a/api/tests/unit_tests/core/helper/code_executor/test_code_executor.py b/api/tests/unit_tests/core/helper/code_executor/test_code_executor.py new file mode 100644 index 00000000000..e09dd03489d --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/test_code_executor.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.helper.code_executor import code_executor as code_executor_module + + +def test_execute_workflow_code_template_raises_for_unsupported_language() -> None: + with pytest.raises(code_executor_module.CodeExecutionError, match="Unsupported language"): + code_executor_module.CodeExecutor.execute_workflow_code_template(cast(Any, "ruby"), "print(1)", {}) + + +def 
test_execute_workflow_code_template_uses_transformer(mocker: MockerFixture) -> None: + transformer = MagicMock() + transformer.transform_caller.return_value = ("runner-script", "preload-script") + transformer.transform_response.return_value = {"result": "ok"} + execute_mock = mocker.patch.object( + code_executor_module.CodeExecutor, + "execute_code", + return_value='<>{"result":"ok"}<>', + ) + mocker.patch.dict(code_executor_module.CodeExecutor.code_template_transformers, {"fake": transformer}, clear=False) + + result = code_executor_module.CodeExecutor.execute_workflow_code_template(cast(Any, "fake"), "code", {"a": 1}) + + assert result == {"result": "ok"} + transformer.transform_caller.assert_called_once_with("code", {"a": 1}) + execute_mock.assert_called_once_with("fake", "preload-script", "runner-script") + + +def test_execute_code_raises_service_unavailable_for_503(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 503 + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="service is unavailable"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + +def test_execute_code_returns_stdout_on_success(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"code": 0, "message": "ok", "data": {"stdout": "done", "error": None}} + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + assert code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") == "done" + + +def test_execute_code_raises_for_non_200_status(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 500 + 
client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="likely a network issue"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + +def test_execute_code_raises_when_client_post_fails(mocker: MockerFixture) -> None: + client = MagicMock() + client.post.side_effect = RuntimeError("timeout") + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="likely a network issue"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + +def test_execute_code_raises_when_response_json_is_invalid(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 200 + response.json.side_effect = ValueError("bad json") + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="Failed to parse response"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + +def test_execute_code_raises_when_sandbox_returns_error_code(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"code": 1, "message": "boom", "data": {"stdout": "", "error": None}} + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="Got error code: 1"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + 
+def test_execute_code_raises_when_response_contains_runtime_error(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"code": 0, "message": "ok", "data": {"stdout": "", "error": "runtime failed"}} + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="runtime failed"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") diff --git a/api/tests/unit_tests/core/helper/code_executor/test_code_node_provider.py b/api/tests/unit_tests/core/helper/code_executor/test_code_node_provider.py new file mode 100644 index 00000000000..47761a32ac8 --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/test_code_node_provider.py @@ -0,0 +1,29 @@ +from core.helper.code_executor.code_node_provider import CodeNodeProvider + + +class _DummyProvider(CodeNodeProvider): + @staticmethod + def get_language() -> str: + return "dummy" + + @classmethod + def get_default_code(cls) -> str: + return "def main():\n return {'result': 'ok'}" + + +def test_is_accept_language() -> None: + assert _DummyProvider.is_accept_language("dummy") is True + assert _DummyProvider.is_accept_language("other") is False + + +def test_get_default_config_contains_expected_shape() -> None: + config = _DummyProvider.get_default_config() + + assert config["type"] == "code" + assert config["config"]["code_language"] == "dummy" + assert config["config"]["code"] == _DummyProvider.get_default_code() + assert config["config"]["variables"] == [ + {"variable": "arg1", "value_selector": []}, + {"variable": "arg2", "value_selector": []}, + ] + assert config["config"]["outputs"] == {"result": {"type": "string", "children": None}} diff --git a/api/tests/unit_tests/core/helper/code_executor/test_template_transformer.py 
b/api/tests/unit_tests/core/helper/code_executor/test_template_transformer.py new file mode 100644 index 00000000000..5b54b8e6474 --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/test_template_transformer.py @@ -0,0 +1,81 @@ +import json +from base64 import b64decode +from collections.abc import Mapping +from typing import Any + +import pytest + +from core.helper.code_executor.template_transformer import TemplateTransformer + + +class _DummyTransformer(TemplateTransformer): + @classmethod + def get_runner_script(cls) -> str: + return f"CODE={cls._code_placeholder};INPUTS={cls._inputs_placeholder}" + + +def test_serialize_code_encodes_to_base64() -> None: + encoded = _DummyTransformer.serialize_code("print('hi')") + + assert b64decode(encoded.encode()).decode() == "print('hi')" + + +def test_assemble_runner_script_embeds_code_and_inputs() -> None: + script = _DummyTransformer.assemble_runner_script("x = 1", {"a": "b"}) + + assert "CODE=x = 1" in script + payload = script.split("INPUTS=", maxsplit=1)[1] + assert json.loads(b64decode(payload.encode()).decode()) == {"a": "b"} + + +def test_transform_caller_returns_runner_and_empty_preload() -> None: + runner, preload = _DummyTransformer.transform_caller("x = 2", {"k": "v"}) + + assert "CODE=x = 2" in runner + assert preload == "" + + +def test_serialize_inputs_encodes_payload() -> None: + payload = _DummyTransformer.serialize_inputs({"foo": "bar"}) + + assert json.loads(b64decode(payload.encode()).decode()) == {"foo": "bar"} + + +def test_transform_response_parses_json_result_and_converts_scientific_notation() -> None: + response = '<>{"value": "1e+3", "nested": {"x": "2E-2"}, "arr": ["3e+1"]}<>' + + result: Mapping[str, Any] = _DummyTransformer.transform_response(response) + + assert result == {"value": 1000.0, "nested": {"x": 0.02}, "arr": [30.0]} + + +def test_transform_response_raises_for_invalid_json() -> None: + with pytest.raises(ValueError, match="Failed to parse JSON response"): + 
_DummyTransformer.transform_response("<>{invalid json}<>") + + +def test_transform_response_raises_for_non_dict_result() -> None: + with pytest.raises(ValueError, match="Result must be a dict"): + _DummyTransformer.transform_response("<>[1,2,3]<>") + + +def test_transform_response_raises_for_non_string_keys(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("json.loads", lambda _: {1: "x"}) + + with pytest.raises(ValueError, match="Result keys must be strings"): + _DummyTransformer.transform_response('<>{"ignored": true}<>') + + +def test_transform_response_raises_for_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None: + def _raise_unexpected(_: str) -> Any: + raise RuntimeError("boom") + + monkeypatch.setattr("json.loads", _raise_unexpected) + + with pytest.raises(ValueError, match="Unexpected error during response transformation"): + _DummyTransformer.transform_response('<>{"a":1}<>') + + +def test_transform_response_raises_for_missing_result_tag() -> None: + with pytest.raises(ValueError, match="no result tag found"): + _DummyTransformer.transform_response("plain output") diff --git a/api/tests/unit_tests/core/helper/test_credential_utils.py b/api/tests/unit_tests/core/helper/test_credential_utils.py new file mode 100644 index 00000000000..7e0d7d0af70 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_credential_utils.py @@ -0,0 +1,138 @@ +from types import SimpleNamespace +from typing import cast + +import pytest +from pytest_mock import MockerFixture + +from core.helper.credential_utils import check_credential_policy_compliance, is_credential_exists +from services.enterprise.plugin_manager_service import PluginCredentialType + + +def test_check_credential_policy_compliance_returns_when_feature_disabled( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=False)), + ) + check_call = mocker.patch( + 
"services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance" + ) + + check_credential_policy_compliance("cred-1", "openai", PluginCredentialType.MODEL) + + check_call.assert_not_called() + + +def test_check_credential_policy_compliance_raises_when_credential_missing( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)), + ) + mocker.patch("core.helper.credential_utils.is_credential_exists", return_value=False) + + with pytest.raises(ValueError, match="Credential with id cred-1 for provider openai not found."): + check_credential_policy_compliance("cred-1", "openai", PluginCredentialType.TOOL) + + +def test_check_credential_policy_compliance_calls_plugin_manager_with_request( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)), + ) + mocker.patch("core.helper.credential_utils.is_credential_exists", return_value=True) + check_call = mocker.patch( + "services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance" + ) + + check_credential_policy_compliance("cred-1", "openai", PluginCredentialType.MODEL) + + check_call.assert_called_once() + request_arg = check_call.call_args.args[0] + assert request_arg.dify_credential_id == "cred-1" + assert request_arg.provider == "openai" + assert request_arg.credential_type == PluginCredentialType.MODEL + + +def test_check_credential_policy_compliance_skips_existence_check_when_disabled( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)), + ) + exists_call = mocker.patch("core.helper.credential_utils.is_credential_exists") + check_call 
= mocker.patch( + "services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance" + ) + + check_credential_policy_compliance( + credential_id="cred-1", + provider="openai", + credential_type=PluginCredentialType.MODEL, + check_existence=False, + ) + + exists_call.assert_not_called() + check_call.assert_called_once() + + +def test_check_credential_policy_compliance_returns_when_credential_id_empty( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)), + ) + exists_call = mocker.patch("core.helper.credential_utils.is_credential_exists") + check_call = mocker.patch( + "services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance" + ) + + check_credential_policy_compliance("", "openai", PluginCredentialType.MODEL) + + exists_call.assert_not_called() + check_call.assert_not_called() + + +@pytest.mark.parametrize( + ("credential_type", "scalar_result", "expected"), + [ + (PluginCredentialType.MODEL, "model-credential", True), + (PluginCredentialType.MODEL, None, False), + (PluginCredentialType.TOOL, "tool-credential", True), + (PluginCredentialType.TOOL, None, False), + ], +) +def test_is_credential_exists_by_type( + mocker: MockerFixture, + credential_type: PluginCredentialType, + scalar_result: str | None, + expected: bool, +) -> None: + mocker.patch("extensions.ext_database.db", new=SimpleNamespace(engine=object())) + session_cls = mocker.patch("sqlalchemy.orm.Session") + session = session_cls.return_value.__enter__.return_value + session.scalar.return_value = scalar_result + + result = is_credential_exists("cred-1", credential_type) + + assert result is expected + session.scalar.assert_called_once() + + +def test_is_credential_exists_returns_false_for_unknown_type( + mocker: MockerFixture, +) -> None: + mocker.patch("extensions.ext_database.db", 
new=SimpleNamespace(engine=object())) + session_cls = mocker.patch("sqlalchemy.orm.Session") + session = session_cls.return_value.__enter__.return_value + + result = is_credential_exists("cred-1", cast(PluginCredentialType, "unknown")) + + assert result is False + session.scalar.assert_not_called() diff --git a/api/tests/unit_tests/core/helper/test_download.py b/api/tests/unit_tests/core/helper/test_download.py new file mode 100644 index 00000000000..0755c25826b --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_download.py @@ -0,0 +1,53 @@ +from collections.abc import Iterator + +import pytest +from pytest_mock import MockerFixture + +from core.helper.download import download_with_size_limit + + +class _StubResponse: + def __init__(self, status_code: int, chunks: list[bytes]) -> None: + self.status_code = status_code + self._chunks = chunks + + def iter_bytes(self) -> Iterator[bytes]: + return iter(self._chunks) + + +def test_download_with_size_limit_returns_content(mocker: MockerFixture) -> None: + response = _StubResponse(status_code=200, chunks=[b"ab", b"cd", b"ef"]) + mock_get = mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response) + + content = download_with_size_limit("https://example.com/a.txt", max_download_size=6, timeout=10) + + assert content == b"abcdef" + mock_get.assert_called_once_with("https://example.com/a.txt", follow_redirects=True, timeout=10) + + +def test_download_with_size_limit_raises_for_404(mocker: MockerFixture) -> None: + mocker.patch("core.helper.download.ssrf_proxy.get", return_value=_StubResponse(status_code=404, chunks=[])) + + with pytest.raises(ValueError, match="file not found"): + download_with_size_limit("https://example.com/missing.txt", max_download_size=10) + + +def test_download_with_size_limit_raises_when_size_exceeds_limit( + mocker: MockerFixture, +) -> None: + response = _StubResponse(status_code=200, chunks=[b"abc", b"de"]) + mocker.patch("core.helper.download.ssrf_proxy.get", 
return_value=response) + + with pytest.raises(ValueError, match="Max file size reached"): + download_with_size_limit("https://example.com/large.bin", max_download_size=4) + + +def test_download_with_size_limit_accepts_content_equal_to_limit( + mocker: MockerFixture, +) -> None: + response = _StubResponse(status_code=200, chunks=[b"ab", b"cd"]) + mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response) + + content = download_with_size_limit("https://example.com/exact.bin", max_download_size=4) + + assert content == b"abcd" diff --git a/api/tests/unit_tests/core/helper/test_http_client_pooling.py b/api/tests/unit_tests/core/helper/test_http_client_pooling.py new file mode 100644 index 00000000000..c29962f1b1a --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_http_client_pooling.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import httpx + +from core.helper.http_client_pooling import HttpClientPoolFactory + + +def test_get_or_create_reuses_client_for_same_key() -> None: + factory = HttpClientPoolFactory() + first_client = MagicMock(spec=httpx.Client) + second_client = MagicMock(spec=httpx.Client) + clients = [first_client, second_client] + + def _builder() -> httpx.Client: + return clients.pop(0) + + assert factory.get_or_create("shared", _builder) is first_client + assert factory.get_or_create("shared", _builder) is first_client + + +def test_get_or_create_creates_distinct_clients_for_distinct_keys() -> None: + factory = HttpClientPoolFactory() + client_a = MagicMock(spec=httpx.Client) + client_b = MagicMock(spec=httpx.Client) + + assert factory.get_or_create("a", lambda: client_a) is client_a + assert factory.get_or_create("b", lambda: client_b) is client_b + + +def test_close_all_closes_pooled_clients_and_allows_recreate() -> None: + factory = HttpClientPoolFactory() + first_client = MagicMock(spec=httpx.Client) + replacement_client = MagicMock(spec=httpx.Client) + + assert 
factory.get_or_create("x", lambda: first_client) is first_client + factory.close_all() + + first_client.close.assert_called_once() + assert factory.get_or_create("x", lambda: replacement_client) is replacement_client diff --git a/api/tests/unit_tests/core/helper/test_marketplace.py b/api/tests/unit_tests/core/helper/test_marketplace.py new file mode 100644 index 00000000000..bd561b16373 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_marketplace.py @@ -0,0 +1,110 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +from pytest_mock import MockerFixture + +from core.helper.marketplace import ( + batch_fetch_plugin_by_ids, + batch_fetch_plugin_manifests, + download_plugin_pkg, + fetch_global_plugin_manifest, + get_plugin_pkg_url, + record_install_plugin_event, +) + + +def test_get_plugin_pkg_url_contains_unique_identifier() -> None: + url = get_plugin_pkg_url("plugin@1.0.0") + + assert "api/v1/plugins/download" in url + assert "unique_identifier=plugin@1.0.0" in url + + +def test_download_plugin_pkg_delegates_with_configured_size(mocker: MockerFixture) -> None: + mocked_download = mocker.patch("core.helper.marketplace.download_with_size_limit", return_value=b"pkg") + mocker.patch("core.helper.marketplace.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 1234) + + result = download_plugin_pkg("plugin.a.b") + + assert result == b"pkg" + mocked_download.assert_called_once() + called_url, called_limit = mocked_download.call_args.args + assert "unique_identifier=plugin.a.b" in called_url + assert called_limit == 1234 + + +def test_batch_fetch_plugin_by_ids_returns_empty_for_empty_input(mocker: MockerFixture) -> None: + post_mock = mocker.patch("core.helper.marketplace.httpx.post") + + assert batch_fetch_plugin_by_ids([]) == [] + post_mock.assert_not_called() + + +def test_batch_fetch_plugin_by_ids_returns_plugins_from_response(mocker: MockerFixture) -> None: + response = MagicMock() + response.json.return_value = {"data": {"plugins": [{"id": 
"p1"}]}} + response.raise_for_status.return_value = None + post_mock = mocker.patch("core.helper.marketplace.httpx.post", return_value=response) + + plugins = batch_fetch_plugin_by_ids(["p1"]) + + assert plugins == [{"id": "p1"}] + post_mock.assert_called_once() + response.raise_for_status.assert_called_once() + + +def test_batch_fetch_plugin_manifests_returns_empty_for_empty_input(mocker: MockerFixture) -> None: + post_mock = mocker.patch("core.helper.marketplace.httpx.post") + + assert batch_fetch_plugin_manifests([]) == [] + post_mock.assert_not_called() + + +def test_batch_fetch_plugin_manifests_validates_and_returns_plugins(mocker: MockerFixture) -> None: + response = MagicMock() + response.raise_for_status.return_value = None + response.json.return_value = {"data": {"plugins": [{"id": "p1"}, {"id": "p2"}]}} + post_mock = mocker.patch("core.helper.marketplace.httpx.post", return_value=response) + validate_mock = mocker.patch( + "core.helper.marketplace.MarketplacePluginDeclaration.model_validate", + side_effect=["manifest-1", "manifest-2"], + ) + + result = batch_fetch_plugin_manifests(["p1", "p2"]) + + assert result == ["manifest-1", "manifest-2"] + post_mock.assert_called_once() + assert validate_mock.call_count == 2 + response.raise_for_status.assert_called_once() + + +def test_record_install_plugin_event_posts_and_checks_status(mocker: MockerFixture) -> None: + response = MagicMock() + response.raise_for_status.return_value = None + post_mock = mocker.patch("core.helper.marketplace.httpx.post", return_value=response) + + record_install_plugin_event("plugin.a") + + post_mock.assert_called_once() + response.raise_for_status.assert_called_once() + + +def test_fetch_global_plugin_manifest_caches_each_plugin(mocker: MockerFixture) -> None: + response = MagicMock() + response.raise_for_status.return_value = None + response.json.return_value = {"plugins": [{"id": "a"}, {"id": "b"}]} + mocker.patch("core.helper.marketplace.httpx.get", return_value=response) + + 
snapshot_a = SimpleNamespace(plugin_id="plugin-a", model_dump_json=lambda: '{"id":"a"}') + snapshot_b = SimpleNamespace(plugin_id="plugin-b", model_dump_json=lambda: '{"id":"b"}') + validate_mock = mocker.patch( + "core.helper.marketplace.MarketplacePluginSnapshot.model_validate", + side_effect=[snapshot_a, snapshot_b], + ) + setex_mock = mocker.patch("core.helper.marketplace.redis_client.setex") + + fetch_global_plugin_manifest("prefix:", 60) + + assert validate_mock.call_count == 2 + setex_mock.assert_any_call(name="prefix:plugin-a", time=60, value='{"id":"a"}') + setex_mock.assert_any_call(name="prefix:plugin-b", time=60, value='{"id":"b"}') diff --git a/api/tests/unit_tests/core/helper/test_moderation.py b/api/tests/unit_tests/core/helper/test_moderation.py new file mode 100644 index 00000000000..4a84099b74f --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_moderation.py @@ -0,0 +1,158 @@ +from types import SimpleNamespace +from typing import cast + +import pytest +from core.model_runtime.errors.invoke import InvokeBadRequestError +from pytest_mock import MockerFixture + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.helper.moderation import check_moderation +from models.provider import ProviderType + + +def _build_model_config(provider: str = "openai") -> SimpleNamespace: + return SimpleNamespace( + provider=provider, + provider_model_bundle=SimpleNamespace( + configuration=SimpleNamespace(using_provider_type=ProviderType.SYSTEM), + ), + ) + + +def test_check_moderation_returns_false_when_feature_not_enabled(mocker: MockerFixture) -> None: + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace(moderation_config=None, provider_map={}), + ) + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "hello", + ) + is False + ) + + +def test_check_moderation_returns_false_when_hosting_credentials_missing(mocker: 
MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: SimpleNamespace(enabled=True, credentials=None)}, + ), + ) + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "hello", + ) + is False + ) + + +def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"}) + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: hosting_openai}, + ), + ) + mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") + + moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk") + factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) + mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "abc", + ) + is True + ) + + +def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"}) + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: hosting_openai}, + ), + ) + factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory") + choice_mock = 
mocker.patch("core.helper.moderation.secrets.choice") + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "", + ) + is True + ) + factory_mock.assert_not_called() + choice_mock.assert_not_called() + + +def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"}) + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: hosting_openai}, + ), + ) + mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") + + moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False) + factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) + mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "abc", + ) + is False + ) + + +def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"}) + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: hosting_openai}, + ), + ) + mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") + + failing_model = SimpleNamespace( + invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")), + ) + factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model) + 
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + + with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."): + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "abc", + ) diff --git a/api/tests/unit_tests/core/helper/test_name_generator.py b/api/tests/unit_tests/core/helper/test_name_generator.py new file mode 100644 index 00000000000..37a87260f15 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_name_generator.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass + +from pytest_mock import MockerFixture + +from core.helper.name_generator import generate_incremental_name, generate_provider_name +from core.plugin.entities.plugin_daemon import CredentialType + + +@dataclass +class _Provider: + name: str + + +def test_generate_incremental_name_uses_next_highest_suffix() -> None: + names = ["API KEY 1", "API KEY 3", "API KEY 2", "other", "", "API KEY x"] + + assert generate_incremental_name(names, "API KEY") == "API KEY 4" + + +def test_generate_incremental_name_returns_default_when_no_matches() -> None: + assert generate_incremental_name(["custom", " ", ""], "AUTH") == "AUTH 1" + + +def test_generate_provider_name_uses_credential_display_name() -> None: + providers = [_Provider(name="API KEY 1"), _Provider(name="API KEY 2")] + + assert generate_provider_name(providers, CredentialType.API_KEY) == "API KEY 3" + + +def test_generate_provider_name_falls_back_on_generation_error(mocker: MockerFixture) -> None: + mocker.patch("core.helper.name_generator.generate_incremental_name", side_effect=RuntimeError("boom")) + + assert generate_provider_name([], CredentialType.OAUTH2, fallback_context="ctx") == "AUTH 1" diff --git a/api/tests/unit_tests/core/helper/test_tool_parameter_cache.py b/api/tests/unit_tests/core/helper/test_tool_parameter_cache.py new file mode 100644 index 00000000000..3c8b44d0101 --- /dev/null +++ 
b/api/tests/unit_tests/core/helper/test_tool_parameter_cache.py @@ -0,0 +1,71 @@ +import json + +from pytest_mock import MockerFixture + +from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType + + +def test_tool_parameter_cache_get_returns_decoded_dict(mocker: MockerFixture) -> None: + redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client") + cache = ToolParameterCache( + tenant_id="tenant", + provider="provider", + tool_name="tool", + cache_type=ToolParameterCacheType.PARAMETER, + identity_id="identity", + ) + payload = {"k": "v", "n": 1} + cache_key = cache.cache_key + + redis_client_mock.get.return_value = json.dumps(payload).encode("utf-8") + + assert cache.get() == payload + redis_client_mock.get.assert_called_once_with(cache_key) + + +def test_tool_parameter_cache_get_returns_none_for_invalid_json(mocker: MockerFixture) -> None: + redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client") + cache = ToolParameterCache( + tenant_id="tenant", + provider="provider", + tool_name="tool", + cache_type=ToolParameterCacheType.PARAMETER, + identity_id="identity", + ) + + redis_client_mock.get.return_value = b"{invalid-json" + + assert cache.get() is None + + +def test_tool_parameter_cache_get_returns_none_when_key_is_missing(mocker: MockerFixture) -> None: + redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client") + cache = ToolParameterCache( + tenant_id="tenant", + provider="provider", + tool_name="tool", + cache_type=ToolParameterCacheType.PARAMETER, + identity_id="identity", + ) + + redis_client_mock.get.return_value = None + + assert cache.get() is None + + +def test_tool_parameter_cache_set_and_delete(mocker: MockerFixture) -> None: + redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client") + cache = ToolParameterCache( + tenant_id="tenant", + provider="provider", + tool_name="tool", + cache_type=ToolParameterCacheType.PARAMETER, + 
identity_id="identity", + ) + + params = {"a": "b"} + cache.set(params) + cache.delete() + + redis_client_mock.setex.assert_called_once_with(cache.cache_key, 86400, json.dumps(params)) + redis_client_mock.delete.assert_called_once_with(cache.cache_key)