mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 15:29:21 +08:00
feat: server multi models support (#799)
This commit is contained in:
0
api/tests/unit_tests/model_providers/__init__.py
Normal file
0
api/tests/unit_tests/model_providers/__init__.py
Normal file
44
api/tests/unit_tests/model_providers/fake_model_provider.py
Normal file
44
api/tests/unit_tests/model_providers/fake_model_provider.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Type
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
|
||||
from core.model_providers.models.llm.openai_model import OpenAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
|
||||
class FakeModelProvider(BaseModelProvider):
|
||||
@property
|
||||
def provider_name(self):
|
||||
return 'fake'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return [{'id': 'test_model', 'name': 'Test Model'}]
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
return OpenAIModel
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
return credentials
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
return credentials
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
return ModelKwargsRules()
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
return {}
|
||||
123
api/tests/unit_tests/model_providers/test_anthropic_provider.py
Normal file
123
api/tests/unit_tests/model_providers/test_anthropic_provider.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.schema import BaseMessage, ChatResult, ChatGeneration, AIMessage
|
||||
|
||||
from core.model_providers.providers.anthropic_provider import AnthropicProvider
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
|
||||
PROVIDER_NAME = 'anthropic'
|
||||
MODEL_PROVIDER_CLASS = AnthropicProvider
|
||||
VALIDATE_CREDENTIAL_KEY = 'anthropic_api_key'
|
||||
|
||||
|
||||
def mock_chat_generate(messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any):
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content='answer'))])
|
||||
|
||||
|
||||
def mock_chat_generate_invalid(messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any):
|
||||
raise anthropic.APIStatusError('Invalid credentials',
|
||||
request=httpx._models.Request(
|
||||
method='POST',
|
||||
url='https://api.anthropic.com/v1/completions',
|
||||
),
|
||||
response=httpx._models.Response(
|
||||
status_code=401,
|
||||
),
|
||||
body=None
|
||||
)
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
@patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate)
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mock_create):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'valid_key'})
|
||||
|
||||
|
||||
@patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate_invalid)
|
||||
def test_is_provider_credentials_valid_or_raise_invalid(mock_create):
|
||||
# raise CredentialsValidateFailedError if anthropic_api_key is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
|
||||
|
||||
# raise CredentialsValidateFailedError if anthropic_api_key is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'})
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_credentials(mock_encrypt):
|
||||
api_key = 'valid_key'
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key})
|
||||
mock_encrypt.assert_called_with('tenant_id', api_key)
|
||||
assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom(mock_decrypt):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_obfuscated(mock_decrypt):
|
||||
api_key = 'valid_key'
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{api_key}'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials(obfuscated=True)
|
||||
middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2]
|
||||
assert len(middle_token) == max(len(api_key) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
|
||||
|
||||
@patch('core.model_providers.providers.hosted.hosted_model_providers.anthropic')
|
||||
def test_get_credentials_hosted(mock_hosted):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
encrypted_config='',
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
mock_hosted.api_key = 'hosted_key'
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key'
|
||||
@@ -0,0 +1,117 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from models.provider import ProviderType, Provider, ProviderModel
|
||||
|
||||
PROVIDER_NAME = 'azure_openai'
|
||||
MODEL_PROVIDER_CLASS = AzureOpenAIProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'openai_api_base': 'https://xxxx.openai.azure.com/',
|
||||
'openai_api_key': 'valid_key',
|
||||
'base_model_name': 'gpt-35-turbo'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_model_credentials_valid_or_raise(mocker):
|
||||
mocker.patch('langchain.chat_models.base.BaseChatModel.generate', return_value=None)
|
||||
|
||||
# assert True if credentials is valid
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=VALIDATE_CREDENTIAL
|
||||
)
|
||||
|
||||
|
||||
def test_is_model_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if credentials is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={}
|
||||
)
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_model_credentials(mock_encrypt):
|
||||
openai_api_key = 'valid_key'
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
|
||||
tenant_id='tenant_id',
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={'openai_api_key': openai_api_key}
|
||||
)
|
||||
mock_encrypt.assert_called_with('tenant_id', openai_api_key)
|
||||
assert result['openai_api_key'] == f'encrypted_{openai_api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_custom(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION
|
||||
)
|
||||
assert result['openai_api_key'] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
obfuscated=True
|
||||
)
|
||||
middle_token = result['openai_api_key'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['openai_api_key']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
@@ -0,0 +1,72 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_providers.error import QuotaExceededError
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from models.provider import Provider, ProviderType
|
||||
from tests.unit_tests.model_providers.fake_model_provider import FakeModelProvider
|
||||
|
||||
|
||||
def test_get_supported_model_list(mocker):
|
||||
mocker.patch.object(
|
||||
FakeModelProvider,
|
||||
'get_rules',
|
||||
return_value={'support_provider_types': ['custom'], 'model_flexibility': 'configurable'}
|
||||
)
|
||||
|
||||
mock_provider_model = MagicMock()
|
||||
mock_provider_model.model_name = 'test_model'
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.order_by.return_value.all.return_value = [mock_provider_model]
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
provider = FakeModelProvider(provider=Provider())
|
||||
result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
|
||||
|
||||
assert result == [{'id': 'test_model', 'name': 'test_model'}]
|
||||
|
||||
|
||||
def test_check_quota_over_limit(mocker):
|
||||
mocker.patch.object(
|
||||
FakeModelProvider,
|
||||
'get_rules',
|
||||
return_value={'support_provider_types': ['system']}
|
||||
)
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = None
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value))
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
provider.check_quota_over_limit()
|
||||
|
||||
|
||||
def test_check_quota_not_over_limit(mocker):
|
||||
mocker.patch.object(
|
||||
FakeModelProvider,
|
||||
'get_rules',
|
||||
return_value={'support_provider_types': ['system']}
|
||||
)
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = Provider()
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value))
|
||||
|
||||
assert provider.check_quota_over_limit() is None
|
||||
|
||||
|
||||
def test_check_custom_quota_over_limit(mocker):
|
||||
mocker.patch.object(
|
||||
FakeModelProvider,
|
||||
'get_rules',
|
||||
return_value={'support_provider_types': ['custom']}
|
||||
)
|
||||
|
||||
provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.CUSTOM.value))
|
||||
|
||||
assert provider.check_quota_over_limit() is None
|
||||
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
from langchain.schema import LLMResult, Generation, AIMessage, ChatResult, ChatGeneration
|
||||
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
|
||||
from core.model_providers.providers.spark_provider import SparkProvider
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
|
||||
PROVIDER_NAME = 'chatglm'
|
||||
MODEL_PROVIDER_CLASS = ChatGLMProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'api_base': 'valid_api_base',
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('langchain.llms.chatglm.ChatGLM._call',
|
||||
return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if api_key is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
|
||||
|
||||
credential = VALIDATE_CREDENTIAL.copy()
|
||||
credential['api_base'] = 'invalid_api_base'
|
||||
|
||||
# raise CredentialsValidateFailedError if api_key is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_credentials(mock_encrypt):
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
|
||||
assert result['api_base'] == f'encrypted_{VALIDATE_CREDENTIAL["api_base"]}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result['api_base'] == 'valid_api_base'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_obfuscated(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials(obfuscated=True)
|
||||
middle_token = result['api_base'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_base']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
@@ -0,0 +1,161 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
|
||||
from models.provider import ProviderType, Provider, ProviderModel
|
||||
|
||||
PROVIDER_NAME = 'huggingface_hub'
|
||||
MODEL_PROVIDER_CLASS = HuggingfaceHubProvider
|
||||
HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
'huggingfacehub_api_token': 'valid_key'
|
||||
}
|
||||
|
||||
INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': 'valid_key',
|
||||
'huggingfacehub_endpoint_url': 'valid_url'
|
||||
}
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
@patch('huggingface_hub.hf_api.ModelInfo')
|
||||
def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker):
|
||||
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
|
||||
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL
|
||||
)
|
||||
|
||||
@patch('huggingface_hub.hf_api.ModelInfo')
|
||||
def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_info):
|
||||
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={}
|
||||
)
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
})
|
||||
|
||||
|
||||
def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
|
||||
mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
|
||||
)
|
||||
|
||||
def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
|
||||
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={}
|
||||
)
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_endpoint_url': 'valid_url'
|
||||
})
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_model_credentials(mock_encrypt):
|
||||
api_key = 'valid_key'
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
|
||||
tenant_id='tenant_id',
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
mock_encrypt.assert_called_with('tenant_id', api_key)
|
||||
assert result['huggingfacehub_api_token'] == f'encrypted_{api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_custom(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['huggingfacehub_api_token'] = 'encrypted_' + encrypted_credential['huggingfacehub_api_token']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION
|
||||
)
|
||||
assert result['huggingfacehub_api_token'] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['huggingfacehub_api_token'] = 'encrypted_' + encrypted_credential['huggingfacehub_api_token']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
obfuscated=True
|
||||
)
|
||||
middle_token = result['huggingfacehub_api_token'][6:-2]
|
||||
assert len(middle_token) == max(len(INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL['huggingfacehub_api_token']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
@@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.minimax_provider import MinimaxProvider
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
|
||||
PROVIDER_NAME = 'minimax'
|
||||
MODEL_PROVIDER_CLASS = MinimaxProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'minimax_group_id': 'fake-group-id',
|
||||
'minimax_api_key': 'valid_key'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('langchain.llms.minimax.Minimax._call', return_value='abc')
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if api_key is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
|
||||
|
||||
credential = VALIDATE_CREDENTIAL.copy()
|
||||
credential['minimax_api_key'] = 'invalid_key'
|
||||
|
||||
# raise CredentialsValidateFailedError if api_key is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_credentials(mock_encrypt):
|
||||
api_key = 'valid_key'
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
|
||||
mock_encrypt.assert_called_with('tenant_id', api_key)
|
||||
assert result['minimax_api_key'] == f'encrypted_{api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['minimax_api_key'] = 'encrypted_' + encrypted_credential['minimax_api_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result['minimax_api_key'] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_obfuscated(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['minimax_api_key'] = 'encrypted_' + encrypted_credential['minimax_api_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials(obfuscated=True)
|
||||
middle_token = result['minimax_api_key'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['minimax_api_key']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
126
api/tests/unit_tests/model_providers/test_openai_provider.py
Normal file
126
api/tests/unit_tests/model_providers/test_openai_provider.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from openai.error import AuthenticationError
|
||||
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.openai_provider import OpenAIProvider
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
PROVIDER_NAME = 'openai'
|
||||
MODEL_PROVIDER_CLASS = OpenAIProvider
|
||||
VALIDATE_CREDENTIAL_KEY = 'openai_api_key'
|
||||
|
||||
|
||||
def moderation_side_effect(*args, **kwargs):
|
||||
if kwargs['api_key'] == 'valid_key':
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.request = MagicMock()
|
||||
return mock_instance, {}
|
||||
else:
|
||||
raise AuthenticationError('Invalid credentials')
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
@patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mock_create):
|
||||
# assert True if api_key is valid
|
||||
credentials = {VALIDATE_CREDENTIAL_KEY: 'valid_key'}
|
||||
assert MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credentials) is None
|
||||
|
||||
|
||||
@patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
|
||||
def test_is_provider_credentials_valid_or_raise_invalid(mock_create):
|
||||
# raise CredentialsValidateFailedError if api_key is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
|
||||
|
||||
# raise CredentialsValidateFailedError if api_key is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'})
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_credentials(mock_encrypt):
|
||||
api_key = 'valid_key'
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key})
|
||||
mock_encrypt.assert_called_with('tenant_id', api_key)
|
||||
assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom(mock_decrypt):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom_str(mock_decrypt):
|
||||
"""
|
||||
Only the OpenAI provider needs to be compatible with the previous case where the encrypted_config was stored as a plain string.
|
||||
|
||||
:param mock_decrypt:
|
||||
:return:
|
||||
"""
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config='encrypted_valid_key',
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_obfuscated(mock_decrypt):
|
||||
openai_api_key = 'valid_key'
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{openai_api_key}'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials(obfuscated=True)
|
||||
middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2]
|
||||
assert len(middle_token) == max(len(openai_api_key) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
|
||||
|
||||
@patch('core.model_providers.providers.hosted.hosted_model_providers.openai')
|
||||
def test_get_credentials_hosted(mock_hosted):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
encrypted_config='',
|
||||
is_valid=True
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
mock_hosted.api_key = 'hosted_key'
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key'
|
||||
125
api/tests/unit_tests/model_providers/test_replicate_provider.py
Normal file
125
api/tests/unit_tests/model_providers/test_replicate_provider.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.replicate_provider import ReplicateProvider
|
||||
from models.provider import ProviderType, Provider, ProviderModel
|
||||
|
||||
PROVIDER_NAME = 'replicate'
|
||||
MODEL_PROVIDER_CLASS = ReplicateProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'model_version': 'fake-version',
|
||||
'replicate_api_token': 'valid_key'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||
mock_query = MagicMock()
|
||||
mock_query.return_value = None
|
||||
mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
|
||||
mocker.patch('replicate.model.Model.versions', return_value=mock_query)
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if replicate_api_token is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={}
|
||||
)
|
||||
|
||||
# raise CredentialsValidateFailedError if replicate_api_token is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials={'replicate_api_token': 'invalid_key'})
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_model_credentials(mock_encrypt):
|
||||
api_key = 'valid_key'
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
|
||||
tenant_id='tenant_id',
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
mock_encrypt.assert_called_with('tenant_id', api_key)
|
||||
assert result['replicate_api_token'] == f'encrypted_{api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_custom(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION
|
||||
)
|
||||
assert result['replicate_api_token'] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=None,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
encrypted_config=json.dumps(encrypted_credential)
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_model_credentials(
|
||||
model_name='test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
obfuscated=True
|
||||
)
|
||||
middle_token = result['replicate_api_token'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['replicate_api_token']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
97
api/tests/unit_tests/model_providers/test_spark_provider.py
Normal file
97
api/tests/unit_tests/model_providers/test_spark_provider.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
from langchain.schema import LLMResult, Generation, AIMessage, ChatResult, ChatGeneration
|
||||
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.spark_provider import SparkProvider
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
|
||||
PROVIDER_NAME = 'spark'
|
||||
MODEL_PROVIDER_CLASS = SparkProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'app_id': 'valid_app_id',
|
||||
'api_key': 'valid_key',
|
||||
'api_secret': 'valid_secret'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('core.third_party.langchain.llms.spark.ChatSpark._generate',
|
||||
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content="abc"))]))
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if api_key is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
|
||||
|
||||
credential = VALIDATE_CREDENTIAL.copy()
|
||||
credential['api_key'] = 'invalid_key'
|
||||
|
||||
# raise CredentialsValidateFailedError if api_key is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_credentials(mock_encrypt):
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
|
||||
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
|
||||
assert result['api_secret'] == f'encrypted_{VALIDATE_CREDENTIAL["api_secret"]}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
|
||||
encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result['api_key'] == 'valid_key'
|
||||
assert result['api_secret'] == 'valid_secret'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_obfuscated(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
|
||||
encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials(obfuscated=True)
|
||||
middle_token = result['api_key'][6:-2]
|
||||
middle_secret = result['api_secret'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
|
||||
assert len(middle_secret) == max(len(VALIDATE_CREDENTIAL['api_secret']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
assert all(char == '*' for char in middle_secret)
|
||||
90
api/tests/unit_tests/model_providers/test_tongyi_provider.py
Normal file
90
api/tests/unit_tests/model_providers/test_tongyi_provider.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
from langchain.schema import LLMResult, Generation
|
||||
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.minimax_provider import MinimaxProvider
|
||||
from core.model_providers.providers.tongyi_provider import TongyiProvider
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
|
||||
PROVIDER_NAME = 'tongyi'
|
||||
MODEL_PROVIDER_CLASS = TongyiProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'dashscope_api_key': 'valid_key'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if api_key is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
|
||||
|
||||
credential = VALIDATE_CREDENTIAL.copy()
|
||||
credential['dashscope_api_key'] = 'invalid_key'
|
||||
|
||||
# raise CredentialsValidateFailedError if api_key is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_credentials(mock_encrypt):
|
||||
api_key = 'valid_key'
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
|
||||
mock_encrypt.assert_called_with('tenant_id', api_key)
|
||||
assert result['dashscope_api_key'] == f'encrypted_{api_key}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result['dashscope_api_key'] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_obfuscated(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials(obfuscated=True)
|
||||
middle_token = result['dashscope_api_key'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['dashscope_api_key']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
93
api/tests/unit_tests/model_providers/test_wenxin_provider.py
Normal file
93
api/tests/unit_tests/model_providers/test_wenxin_provider.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.wenxin_provider import WenxinProvider
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
|
||||
PROVIDER_NAME = 'wenxin'
|
||||
MODEL_PROVIDER_CLASS = WenxinProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'api_key': 'valid_key',
|
||||
'secret_key': 'valid_secret'
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if api_key is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
|
||||
|
||||
credential = VALIDATE_CREDENTIAL.copy()
|
||||
credential['api_key'] = 'invalid_key'
|
||||
|
||||
# raise CredentialsValidateFailedError if api_key is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_credentials(mock_encrypt):
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
|
||||
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
|
||||
assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
|
||||
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result['api_key'] == 'valid_key'
|
||||
assert result['secret_key'] == 'valid_secret'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_obfuscated(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
|
||||
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials(obfuscated=True)
|
||||
middle_token = result['api_key'][6:-2]
|
||||
middle_secret = result['secret_key'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
|
||||
assert len(middle_secret) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
assert all(char == '*' for char in middle_secret)
|
||||
Reference in New Issue
Block a user