mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 05:09:19 +08:00
test: added unit test for remaining files in core helper folder (#33288)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com> Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,24 @@
|
|||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_returns_result_value_as_string(mocker: MockerFixture) -> None:
|
||||||
|
execute_mock = mocker.patch(
|
||||||
|
"core.helper.code_executor.jinja2.jinja2_formatter.CodeExecutor.execute_workflow_code_template",
|
||||||
|
return_value={"result": 123},
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted = Jinja2Formatter.format("Hello {{ name }}", {"name": "Dify"})
|
||||||
|
|
||||||
|
assert formatted == "123"
|
||||||
|
execute_mock.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_returns_empty_string_when_result_missing(mocker: MockerFixture) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"core.helper.code_executor.jinja2.jinja2_formatter.CodeExecutor.execute_workflow_code_template",
|
||||||
|
return_value={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert Jinja2Formatter.format("Hello", {"name": "Dify"}) == ""
|
||||||
@@ -0,0 +1,110 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from core.helper.code_executor import code_executor as code_executor_module
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_workflow_code_template_raises_for_unsupported_language() -> None:
|
||||||
|
with pytest.raises(code_executor_module.CodeExecutionError, match="Unsupported language"):
|
||||||
|
code_executor_module.CodeExecutor.execute_workflow_code_template(cast(Any, "ruby"), "print(1)", {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_workflow_code_template_uses_transformer(mocker: MockerFixture) -> None:
|
||||||
|
transformer = MagicMock()
|
||||||
|
transformer.transform_caller.return_value = ("runner-script", "preload-script")
|
||||||
|
transformer.transform_response.return_value = {"result": "ok"}
|
||||||
|
execute_mock = mocker.patch.object(
|
||||||
|
code_executor_module.CodeExecutor,
|
||||||
|
"execute_code",
|
||||||
|
return_value='<<RESULT>>{"result":"ok"}<<RESULT>>',
|
||||||
|
)
|
||||||
|
mocker.patch.dict(code_executor_module.CodeExecutor.code_template_transformers, {"fake": transformer}, clear=False)
|
||||||
|
|
||||||
|
result = code_executor_module.CodeExecutor.execute_workflow_code_template(cast(Any, "fake"), "code", {"a": 1})
|
||||||
|
|
||||||
|
assert result == {"result": "ok"}
|
||||||
|
transformer.transform_caller.assert_called_once_with("code", {"a": 1})
|
||||||
|
execute_mock.assert_called_once_with("fake", "preload-script", "runner-script")
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_code_raises_service_unavailable_for_503(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.status_code = 503
|
||||||
|
client = MagicMock()
|
||||||
|
client.post.return_value = response
|
||||||
|
mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client)
|
||||||
|
|
||||||
|
with pytest.raises(code_executor_module.CodeExecutionError, match="service is unavailable"):
|
||||||
|
code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)")
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_code_returns_stdout_on_success(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.status_code = 200
|
||||||
|
response.json.return_value = {"code": 0, "message": "ok", "data": {"stdout": "done", "error": None}}
|
||||||
|
client = MagicMock()
|
||||||
|
client.post.return_value = response
|
||||||
|
mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client)
|
||||||
|
|
||||||
|
assert code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") == "done"
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_code_raises_for_non_200_status(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.status_code = 500
|
||||||
|
client = MagicMock()
|
||||||
|
client.post.return_value = response
|
||||||
|
mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client)
|
||||||
|
|
||||||
|
with pytest.raises(code_executor_module.CodeExecutionError, match="likely a network issue"):
|
||||||
|
code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)")
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_code_raises_when_client_post_fails(mocker: MockerFixture) -> None:
|
||||||
|
client = MagicMock()
|
||||||
|
client.post.side_effect = RuntimeError("timeout")
|
||||||
|
mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client)
|
||||||
|
|
||||||
|
with pytest.raises(code_executor_module.CodeExecutionError, match="likely a network issue"):
|
||||||
|
code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)")
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_code_raises_when_response_json_is_invalid(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.status_code = 200
|
||||||
|
response.json.side_effect = ValueError("bad json")
|
||||||
|
client = MagicMock()
|
||||||
|
client.post.return_value = response
|
||||||
|
mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client)
|
||||||
|
|
||||||
|
with pytest.raises(code_executor_module.CodeExecutionError, match="Failed to parse response"):
|
||||||
|
code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)")
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_code_raises_when_sandbox_returns_error_code(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.status_code = 200
|
||||||
|
response.json.return_value = {"code": 1, "message": "boom", "data": {"stdout": "", "error": None}}
|
||||||
|
client = MagicMock()
|
||||||
|
client.post.return_value = response
|
||||||
|
mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client)
|
||||||
|
|
||||||
|
with pytest.raises(code_executor_module.CodeExecutionError, match="Got error code: 1"):
|
||||||
|
code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)")
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_code_raises_when_response_contains_runtime_error(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.status_code = 200
|
||||||
|
response.json.return_value = {"code": 0, "message": "ok", "data": {"stdout": "", "error": "runtime failed"}}
|
||||||
|
client = MagicMock()
|
||||||
|
client.post.return_value = response
|
||||||
|
mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client)
|
||||||
|
|
||||||
|
with pytest.raises(code_executor_module.CodeExecutionError, match="runtime failed"):
|
||||||
|
code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)")
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyProvider(CodeNodeProvider):
|
||||||
|
@staticmethod
|
||||||
|
def get_language() -> str:
|
||||||
|
return "dummy"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_default_code(cls) -> str:
|
||||||
|
return "def main():\n return {'result': 'ok'}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_accept_language() -> None:
|
||||||
|
assert _DummyProvider.is_accept_language("dummy") is True
|
||||||
|
assert _DummyProvider.is_accept_language("other") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_default_config_contains_expected_shape() -> None:
|
||||||
|
config = _DummyProvider.get_default_config()
|
||||||
|
|
||||||
|
assert config["type"] == "code"
|
||||||
|
assert config["config"]["code_language"] == "dummy"
|
||||||
|
assert config["config"]["code"] == _DummyProvider.get_default_code()
|
||||||
|
assert config["config"]["variables"] == [
|
||||||
|
{"variable": "arg1", "value_selector": []},
|
||||||
|
{"variable": "arg2", "value_selector": []},
|
||||||
|
]
|
||||||
|
assert config["config"]["outputs"] == {"result": {"type": "string", "children": None}}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
import json
|
||||||
|
from base64 import b64decode
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyTransformer(TemplateTransformer):
|
||||||
|
@classmethod
|
||||||
|
def get_runner_script(cls) -> str:
|
||||||
|
return f"CODE={cls._code_placeholder};INPUTS={cls._inputs_placeholder}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_code_encodes_to_base64() -> None:
|
||||||
|
encoded = _DummyTransformer.serialize_code("print('hi')")
|
||||||
|
|
||||||
|
assert b64decode(encoded.encode()).decode() == "print('hi')"
|
||||||
|
|
||||||
|
|
||||||
|
def test_assemble_runner_script_embeds_code_and_inputs() -> None:
|
||||||
|
script = _DummyTransformer.assemble_runner_script("x = 1", {"a": "b"})
|
||||||
|
|
||||||
|
assert "CODE=x = 1" in script
|
||||||
|
payload = script.split("INPUTS=", maxsplit=1)[1]
|
||||||
|
assert json.loads(b64decode(payload.encode()).decode()) == {"a": "b"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_caller_returns_runner_and_empty_preload() -> None:
|
||||||
|
runner, preload = _DummyTransformer.transform_caller("x = 2", {"k": "v"})
|
||||||
|
|
||||||
|
assert "CODE=x = 2" in runner
|
||||||
|
assert preload == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_inputs_encodes_payload() -> None:
|
||||||
|
payload = _DummyTransformer.serialize_inputs({"foo": "bar"})
|
||||||
|
|
||||||
|
assert json.loads(b64decode(payload.encode()).decode()) == {"foo": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_response_parses_json_result_and_converts_scientific_notation() -> None:
|
||||||
|
response = '<<RESULT>>{"value": "1e+3", "nested": {"x": "2E-2"}, "arr": ["3e+1"]}<<RESULT>>'
|
||||||
|
|
||||||
|
result: Mapping[str, Any] = _DummyTransformer.transform_response(response)
|
||||||
|
|
||||||
|
assert result == {"value": 1000.0, "nested": {"x": 0.02}, "arr": [30.0]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_response_raises_for_invalid_json() -> None:
|
||||||
|
with pytest.raises(ValueError, match="Failed to parse JSON response"):
|
||||||
|
_DummyTransformer.transform_response("<<RESULT>>{invalid json}<<RESULT>>")
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_response_raises_for_non_dict_result() -> None:
|
||||||
|
with pytest.raises(ValueError, match="Result must be a dict"):
|
||||||
|
_DummyTransformer.transform_response("<<RESULT>>[1,2,3]<<RESULT>>")
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_response_raises_for_non_string_keys(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr("json.loads", lambda _: {1: "x"})
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Result keys must be strings"):
|
||||||
|
_DummyTransformer.transform_response('<<RESULT>>{"ignored": true}<<RESULT>>')
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_response_raises_for_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
def _raise_unexpected(_: str) -> Any:
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
|
monkeypatch.setattr("json.loads", _raise_unexpected)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unexpected error during response transformation"):
|
||||||
|
_DummyTransformer.transform_response('<<RESULT>>{"a":1}<<RESULT>>')
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_response_raises_for_missing_result_tag() -> None:
|
||||||
|
with pytest.raises(ValueError, match="no result tag found"):
|
||||||
|
_DummyTransformer.transform_response("plain output")
|
||||||
138
api/tests/unit_tests/core/helper/test_credential_utils.py
Normal file
138
api/tests/unit_tests/core/helper/test_credential_utils.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from core.helper.credential_utils import check_credential_policy_compliance, is_credential_exists
|
||||||
|
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_credential_policy_compliance_returns_when_feature_disabled(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"services.feature_service.FeatureService.get_system_features",
|
||||||
|
return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=False)),
|
||||||
|
)
|
||||||
|
check_call = mocker.patch(
|
||||||
|
"services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance"
|
||||||
|
)
|
||||||
|
|
||||||
|
check_credential_policy_compliance("cred-1", "openai", PluginCredentialType.MODEL)
|
||||||
|
|
||||||
|
check_call.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_credential_policy_compliance_raises_when_credential_missing(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"services.feature_service.FeatureService.get_system_features",
|
||||||
|
return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)),
|
||||||
|
)
|
||||||
|
mocker.patch("core.helper.credential_utils.is_credential_exists", return_value=False)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Credential with id cred-1 for provider openai not found."):
|
||||||
|
check_credential_policy_compliance("cred-1", "openai", PluginCredentialType.TOOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_credential_policy_compliance_calls_plugin_manager_with_request(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"services.feature_service.FeatureService.get_system_features",
|
||||||
|
return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)),
|
||||||
|
)
|
||||||
|
mocker.patch("core.helper.credential_utils.is_credential_exists", return_value=True)
|
||||||
|
check_call = mocker.patch(
|
||||||
|
"services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance"
|
||||||
|
)
|
||||||
|
|
||||||
|
check_credential_policy_compliance("cred-1", "openai", PluginCredentialType.MODEL)
|
||||||
|
|
||||||
|
check_call.assert_called_once()
|
||||||
|
request_arg = check_call.call_args.args[0]
|
||||||
|
assert request_arg.dify_credential_id == "cred-1"
|
||||||
|
assert request_arg.provider == "openai"
|
||||||
|
assert request_arg.credential_type == PluginCredentialType.MODEL
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_credential_policy_compliance_skips_existence_check_when_disabled(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"services.feature_service.FeatureService.get_system_features",
|
||||||
|
return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)),
|
||||||
|
)
|
||||||
|
exists_call = mocker.patch("core.helper.credential_utils.is_credential_exists")
|
||||||
|
check_call = mocker.patch(
|
||||||
|
"services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance"
|
||||||
|
)
|
||||||
|
|
||||||
|
check_credential_policy_compliance(
|
||||||
|
credential_id="cred-1",
|
||||||
|
provider="openai",
|
||||||
|
credential_type=PluginCredentialType.MODEL,
|
||||||
|
check_existence=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
exists_call.assert_not_called()
|
||||||
|
check_call.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_credential_policy_compliance_returns_when_credential_id_empty(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"services.feature_service.FeatureService.get_system_features",
|
||||||
|
return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)),
|
||||||
|
)
|
||||||
|
exists_call = mocker.patch("core.helper.credential_utils.is_credential_exists")
|
||||||
|
check_call = mocker.patch(
|
||||||
|
"services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance"
|
||||||
|
)
|
||||||
|
|
||||||
|
check_credential_policy_compliance("", "openai", PluginCredentialType.MODEL)
|
||||||
|
|
||||||
|
exists_call.assert_not_called()
|
||||||
|
check_call.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("credential_type", "scalar_result", "expected"),
|
||||||
|
[
|
||||||
|
(PluginCredentialType.MODEL, "model-credential", True),
|
||||||
|
(PluginCredentialType.MODEL, None, False),
|
||||||
|
(PluginCredentialType.TOOL, "tool-credential", True),
|
||||||
|
(PluginCredentialType.TOOL, None, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_credential_exists_by_type(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
credential_type: PluginCredentialType,
|
||||||
|
scalar_result: str | None,
|
||||||
|
expected: bool,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch("extensions.ext_database.db", new=SimpleNamespace(engine=object()))
|
||||||
|
session_cls = mocker.patch("sqlalchemy.orm.Session")
|
||||||
|
session = session_cls.return_value.__enter__.return_value
|
||||||
|
session.scalar.return_value = scalar_result
|
||||||
|
|
||||||
|
result = is_credential_exists("cred-1", credential_type)
|
||||||
|
|
||||||
|
assert result is expected
|
||||||
|
session.scalar.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_credential_exists_returns_false_for_unknown_type(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch("extensions.ext_database.db", new=SimpleNamespace(engine=object()))
|
||||||
|
session_cls = mocker.patch("sqlalchemy.orm.Session")
|
||||||
|
session = session_cls.return_value.__enter__.return_value
|
||||||
|
|
||||||
|
result = is_credential_exists("cred-1", cast(PluginCredentialType, "unknown"))
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
session.scalar.assert_not_called()
|
||||||
53
api/tests/unit_tests/core/helper/test_download.py
Normal file
53
api/tests/unit_tests/core/helper/test_download.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from core.helper.download import download_with_size_limit
|
||||||
|
|
||||||
|
|
||||||
|
class _StubResponse:
|
||||||
|
def __init__(self, status_code: int, chunks: list[bytes]) -> None:
|
||||||
|
self.status_code = status_code
|
||||||
|
self._chunks = chunks
|
||||||
|
|
||||||
|
def iter_bytes(self) -> Iterator[bytes]:
|
||||||
|
return iter(self._chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_with_size_limit_returns_content(mocker: MockerFixture) -> None:
|
||||||
|
response = _StubResponse(status_code=200, chunks=[b"ab", b"cd", b"ef"])
|
||||||
|
mock_get = mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response)
|
||||||
|
|
||||||
|
content = download_with_size_limit("https://example.com/a.txt", max_download_size=6, timeout=10)
|
||||||
|
|
||||||
|
assert content == b"abcdef"
|
||||||
|
mock_get.assert_called_once_with("https://example.com/a.txt", follow_redirects=True, timeout=10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_with_size_limit_raises_for_404(mocker: MockerFixture) -> None:
|
||||||
|
mocker.patch("core.helper.download.ssrf_proxy.get", return_value=_StubResponse(status_code=404, chunks=[]))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="file not found"):
|
||||||
|
download_with_size_limit("https://example.com/missing.txt", max_download_size=10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_with_size_limit_raises_when_size_exceeds_limit(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
response = _StubResponse(status_code=200, chunks=[b"abc", b"de"])
|
||||||
|
mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Max file size reached"):
|
||||||
|
download_with_size_limit("https://example.com/large.bin", max_download_size=4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_with_size_limit_accepts_content_equal_to_limit(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
response = _StubResponse(status_code=200, chunks=[b"ab", b"cd"])
|
||||||
|
mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response)
|
||||||
|
|
||||||
|
content = download_with_size_limit("https://example.com/exact.bin", max_download_size=4)
|
||||||
|
|
||||||
|
assert content == b"abcd"
|
||||||
41
api/tests/unit_tests/core/helper/test_http_client_pooling.py
Normal file
41
api/tests/unit_tests/core/helper/test_http_client_pooling.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from core.helper.http_client_pooling import HttpClientPoolFactory
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_or_create_reuses_client_for_same_key() -> None:
|
||||||
|
factory = HttpClientPoolFactory()
|
||||||
|
first_client = MagicMock(spec=httpx.Client)
|
||||||
|
second_client = MagicMock(spec=httpx.Client)
|
||||||
|
clients = [first_client, second_client]
|
||||||
|
|
||||||
|
def _builder() -> httpx.Client:
|
||||||
|
return clients.pop(0)
|
||||||
|
|
||||||
|
assert factory.get_or_create("shared", _builder) is first_client
|
||||||
|
assert factory.get_or_create("shared", _builder) is first_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_or_create_creates_distinct_clients_for_distinct_keys() -> None:
|
||||||
|
factory = HttpClientPoolFactory()
|
||||||
|
client_a = MagicMock(spec=httpx.Client)
|
||||||
|
client_b = MagicMock(spec=httpx.Client)
|
||||||
|
|
||||||
|
assert factory.get_or_create("a", lambda: client_a) is client_a
|
||||||
|
assert factory.get_or_create("b", lambda: client_b) is client_b
|
||||||
|
|
||||||
|
|
||||||
|
def test_close_all_closes_pooled_clients_and_allows_recreate() -> None:
|
||||||
|
factory = HttpClientPoolFactory()
|
||||||
|
first_client = MagicMock(spec=httpx.Client)
|
||||||
|
replacement_client = MagicMock(spec=httpx.Client)
|
||||||
|
|
||||||
|
assert factory.get_or_create("x", lambda: first_client) is first_client
|
||||||
|
factory.close_all()
|
||||||
|
|
||||||
|
first_client.close.assert_called_once()
|
||||||
|
assert factory.get_or_create("x", lambda: replacement_client) is replacement_client
|
||||||
110
api/tests/unit_tests/core/helper/test_marketplace.py
Normal file
110
api/tests/unit_tests/core/helper/test_marketplace.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from core.helper.marketplace import (
|
||||||
|
batch_fetch_plugin_by_ids,
|
||||||
|
batch_fetch_plugin_manifests,
|
||||||
|
download_plugin_pkg,
|
||||||
|
fetch_global_plugin_manifest,
|
||||||
|
get_plugin_pkg_url,
|
||||||
|
record_install_plugin_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_plugin_pkg_url_contains_unique_identifier() -> None:
|
||||||
|
url = get_plugin_pkg_url("plugin@1.0.0")
|
||||||
|
|
||||||
|
assert "api/v1/plugins/download" in url
|
||||||
|
assert "unique_identifier=plugin@1.0.0" in url
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_plugin_pkg_delegates_with_configured_size(mocker: MockerFixture) -> None:
|
||||||
|
mocked_download = mocker.patch("core.helper.marketplace.download_with_size_limit", return_value=b"pkg")
|
||||||
|
mocker.patch("core.helper.marketplace.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 1234)
|
||||||
|
|
||||||
|
result = download_plugin_pkg("plugin.a.b")
|
||||||
|
|
||||||
|
assert result == b"pkg"
|
||||||
|
mocked_download.assert_called_once()
|
||||||
|
called_url, called_limit = mocked_download.call_args.args
|
||||||
|
assert "unique_identifier=plugin.a.b" in called_url
|
||||||
|
assert called_limit == 1234
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_fetch_plugin_by_ids_returns_empty_for_empty_input(mocker: MockerFixture) -> None:
|
||||||
|
post_mock = mocker.patch("core.helper.marketplace.httpx.post")
|
||||||
|
|
||||||
|
assert batch_fetch_plugin_by_ids([]) == []
|
||||||
|
post_mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_fetch_plugin_by_ids_returns_plugins_from_response(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.json.return_value = {"data": {"plugins": [{"id": "p1"}]}}
|
||||||
|
response.raise_for_status.return_value = None
|
||||||
|
post_mock = mocker.patch("core.helper.marketplace.httpx.post", return_value=response)
|
||||||
|
|
||||||
|
plugins = batch_fetch_plugin_by_ids(["p1"])
|
||||||
|
|
||||||
|
assert plugins == [{"id": "p1"}]
|
||||||
|
post_mock.assert_called_once()
|
||||||
|
response.raise_for_status.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_fetch_plugin_manifests_returns_empty_for_empty_input(mocker: MockerFixture) -> None:
|
||||||
|
post_mock = mocker.patch("core.helper.marketplace.httpx.post")
|
||||||
|
|
||||||
|
assert batch_fetch_plugin_manifests([]) == []
|
||||||
|
post_mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_fetch_plugin_manifests_validates_and_returns_plugins(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.raise_for_status.return_value = None
|
||||||
|
response.json.return_value = {"data": {"plugins": [{"id": "p1"}, {"id": "p2"}]}}
|
||||||
|
post_mock = mocker.patch("core.helper.marketplace.httpx.post", return_value=response)
|
||||||
|
validate_mock = mocker.patch(
|
||||||
|
"core.helper.marketplace.MarketplacePluginDeclaration.model_validate",
|
||||||
|
side_effect=["manifest-1", "manifest-2"],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = batch_fetch_plugin_manifests(["p1", "p2"])
|
||||||
|
|
||||||
|
assert result == ["manifest-1", "manifest-2"]
|
||||||
|
post_mock.assert_called_once()
|
||||||
|
assert validate_mock.call_count == 2
|
||||||
|
response.raise_for_status.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_install_plugin_event_posts_and_checks_status(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.raise_for_status.return_value = None
|
||||||
|
post_mock = mocker.patch("core.helper.marketplace.httpx.post", return_value=response)
|
||||||
|
|
||||||
|
record_install_plugin_event("plugin.a")
|
||||||
|
|
||||||
|
post_mock.assert_called_once()
|
||||||
|
response.raise_for_status.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_global_plugin_manifest_caches_each_plugin(mocker: MockerFixture) -> None:
|
||||||
|
response = MagicMock()
|
||||||
|
response.raise_for_status.return_value = None
|
||||||
|
response.json.return_value = {"plugins": [{"id": "a"}, {"id": "b"}]}
|
||||||
|
mocker.patch("core.helper.marketplace.httpx.get", return_value=response)
|
||||||
|
|
||||||
|
snapshot_a = SimpleNamespace(plugin_id="plugin-a", model_dump_json=lambda: '{"id":"a"}')
|
||||||
|
snapshot_b = SimpleNamespace(plugin_id="plugin-b", model_dump_json=lambda: '{"id":"b"}')
|
||||||
|
validate_mock = mocker.patch(
|
||||||
|
"core.helper.marketplace.MarketplacePluginSnapshot.model_validate",
|
||||||
|
side_effect=[snapshot_a, snapshot_b],
|
||||||
|
)
|
||||||
|
setex_mock = mocker.patch("core.helper.marketplace.redis_client.setex")
|
||||||
|
|
||||||
|
fetch_global_plugin_manifest("prefix:", 60)
|
||||||
|
|
||||||
|
assert validate_mock.call_count == 2
|
||||||
|
setex_mock.assert_any_call(name="prefix:plugin-a", time=60, value='{"id":"a"}')
|
||||||
|
setex_mock.assert_any_call(name="prefix:plugin-b", time=60, value='{"id":"b"}')
|
||||||
158
api/tests/unit_tests/core/helper/test_moderation.py
Normal file
158
api/tests/unit_tests/core/helper/test_moderation.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
|
from core.helper.moderation import check_moderation
|
||||||
|
from models.provider import ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
def _build_model_config(provider: str = "openai") -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(
|
||||||
|
provider=provider,
|
||||||
|
provider_model_bundle=SimpleNamespace(
|
||||||
|
configuration=SimpleNamespace(using_provider_type=ProviderType.SYSTEM),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_moderation_returns_false_when_feature_not_enabled(mocker: MockerFixture) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"core.helper.moderation.hosting_configuration",
|
||||||
|
SimpleNamespace(moderation_config=None, provider_map={}),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
check_moderation(
|
||||||
|
"tenant-1",
|
||||||
|
cast(ModelConfigWithCredentialsEntity, _build_model_config()),
|
||||||
|
"hello",
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_moderation_returns_false_when_hosting_credentials_missing(mocker: MockerFixture) -> None:
|
||||||
|
openai_provider = "langgenius/openai/openai"
|
||||||
|
mocker.patch(
|
||||||
|
"core.helper.moderation.hosting_configuration",
|
||||||
|
SimpleNamespace(
|
||||||
|
moderation_config=SimpleNamespace(enabled=True, providers={"openai"}),
|
||||||
|
provider_map={openai_provider: SimpleNamespace(enabled=True, credentials=None)},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
check_moderation(
|
||||||
|
"tenant-1",
|
||||||
|
cast(ModelConfigWithCredentialsEntity, _build_model_config()),
|
||||||
|
"hello",
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFixture) -> None:
|
||||||
|
openai_provider = "langgenius/openai/openai"
|
||||||
|
hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"})
|
||||||
|
mocker.patch(
|
||||||
|
"core.helper.moderation.hosting_configuration",
|
||||||
|
SimpleNamespace(
|
||||||
|
moderation_config=SimpleNamespace(enabled=True, providers={"openai"}),
|
||||||
|
provider_map={openai_provider: hosting_openai},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
|
||||||
|
|
||||||
|
moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk")
|
||||||
|
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||||
|
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
check_moderation(
|
||||||
|
"tenant-1",
|
||||||
|
cast(ModelConfigWithCredentialsEntity, _build_model_config()),
|
||||||
|
"abc",
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture) -> None:
|
||||||
|
openai_provider = "langgenius/openai/openai"
|
||||||
|
hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"})
|
||||||
|
mocker.patch(
|
||||||
|
"core.helper.moderation.hosting_configuration",
|
||||||
|
SimpleNamespace(
|
||||||
|
moderation_config=SimpleNamespace(enabled=True, providers={"openai"}),
|
||||||
|
provider_map={openai_provider: hosting_openai},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory")
|
||||||
|
choice_mock = mocker.patch("core.helper.moderation.secrets.choice")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
check_moderation(
|
||||||
|
"tenant-1",
|
||||||
|
cast(ModelConfigWithCredentialsEntity, _build_model_config()),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
factory_mock.assert_not_called()
|
||||||
|
choice_mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFixture) -> None:
|
||||||
|
openai_provider = "langgenius/openai/openai"
|
||||||
|
hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"})
|
||||||
|
mocker.patch(
|
||||||
|
"core.helper.moderation.hosting_configuration",
|
||||||
|
SimpleNamespace(
|
||||||
|
moderation_config=SimpleNamespace(enabled=True, providers={"openai"}),
|
||||||
|
provider_map={openai_provider: hosting_openai},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
|
||||||
|
|
||||||
|
moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False)
|
||||||
|
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||||
|
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
check_moderation(
|
||||||
|
"tenant-1",
|
||||||
|
cast(ModelConfigWithCredentialsEntity, _build_model_config()),
|
||||||
|
"abc",
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: MockerFixture) -> None:
|
||||||
|
openai_provider = "langgenius/openai/openai"
|
||||||
|
hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"})
|
||||||
|
mocker.patch(
|
||||||
|
"core.helper.moderation.hosting_configuration",
|
||||||
|
SimpleNamespace(
|
||||||
|
moderation_config=SimpleNamespace(enabled=True, providers={"openai"}),
|
||||||
|
provider_map={openai_provider: hosting_openai},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
|
||||||
|
|
||||||
|
failing_model = SimpleNamespace(
|
||||||
|
invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||||
|
)
|
||||||
|
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model)
|
||||||
|
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||||
|
|
||||||
|
with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."):
|
||||||
|
check_moderation(
|
||||||
|
"tenant-1",
|
||||||
|
cast(ModelConfigWithCredentialsEntity, _build_model_config()),
|
||||||
|
"abc",
|
||||||
|
)
|
||||||
33
api/tests/unit_tests/core/helper/test_name_generator.py
Normal file
33
api/tests/unit_tests/core/helper/test_name_generator.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from core.helper.name_generator import generate_incremental_name, generate_provider_name
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Provider:
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_incremental_name_uses_next_highest_suffix() -> None:
|
||||||
|
names = ["API KEY 1", "API KEY 3", "API KEY 2", "other", "", "API KEY x"]
|
||||||
|
|
||||||
|
assert generate_incremental_name(names, "API KEY") == "API KEY 4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_incremental_name_returns_default_when_no_matches() -> None:
|
||||||
|
assert generate_incremental_name(["custom", " ", ""], "AUTH") == "AUTH 1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_provider_name_uses_credential_display_name() -> None:
|
||||||
|
providers = [_Provider(name="API KEY 1"), _Provider(name="API KEY 2")]
|
||||||
|
|
||||||
|
assert generate_provider_name(providers, CredentialType.API_KEY) == "API KEY 3"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_provider_name_falls_back_on_generation_error(mocker: MockerFixture) -> None:
|
||||||
|
mocker.patch("core.helper.name_generator.generate_incremental_name", side_effect=RuntimeError("boom"))
|
||||||
|
|
||||||
|
assert generate_provider_name([], CredentialType.OAUTH2, fallback_context="ctx") == "AUTH 1"
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_parameter_cache_get_returns_decoded_dict(mocker: MockerFixture) -> None:
|
||||||
|
redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client")
|
||||||
|
cache = ToolParameterCache(
|
||||||
|
tenant_id="tenant",
|
||||||
|
provider="provider",
|
||||||
|
tool_name="tool",
|
||||||
|
cache_type=ToolParameterCacheType.PARAMETER,
|
||||||
|
identity_id="identity",
|
||||||
|
)
|
||||||
|
payload = {"k": "v", "n": 1}
|
||||||
|
cache_key = cache.cache_key
|
||||||
|
|
||||||
|
redis_client_mock.get.return_value = json.dumps(payload).encode("utf-8")
|
||||||
|
|
||||||
|
assert cache.get() == payload
|
||||||
|
redis_client_mock.get.assert_called_once_with(cache_key)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_parameter_cache_get_returns_none_for_invalid_json(mocker: MockerFixture) -> None:
|
||||||
|
redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client")
|
||||||
|
cache = ToolParameterCache(
|
||||||
|
tenant_id="tenant",
|
||||||
|
provider="provider",
|
||||||
|
tool_name="tool",
|
||||||
|
cache_type=ToolParameterCacheType.PARAMETER,
|
||||||
|
identity_id="identity",
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_client_mock.get.return_value = b"{invalid-json"
|
||||||
|
|
||||||
|
assert cache.get() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_parameter_cache_get_returns_none_when_key_is_missing(mocker: MockerFixture) -> None:
|
||||||
|
redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client")
|
||||||
|
cache = ToolParameterCache(
|
||||||
|
tenant_id="tenant",
|
||||||
|
provider="provider",
|
||||||
|
tool_name="tool",
|
||||||
|
cache_type=ToolParameterCacheType.PARAMETER,
|
||||||
|
identity_id="identity",
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_client_mock.get.return_value = None
|
||||||
|
|
||||||
|
assert cache.get() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_parameter_cache_set_and_delete(mocker: MockerFixture) -> None:
|
||||||
|
redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client")
|
||||||
|
cache = ToolParameterCache(
|
||||||
|
tenant_id="tenant",
|
||||||
|
provider="provider",
|
||||||
|
tool_name="tool",
|
||||||
|
cache_type=ToolParameterCacheType.PARAMETER,
|
||||||
|
identity_id="identity",
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {"a": "b"}
|
||||||
|
cache.set(params)
|
||||||
|
cache.delete()
|
||||||
|
|
||||||
|
redis_client_mock.setex.assert_called_once_with(cache.cache_key, 86400, json.dumps(params))
|
||||||
|
redis_client_mock.delete.assert_called_once_with(cache.cache_key)
|
||||||
Reference in New Issue
Block a user