feat: server multi models support (#799)

This commit is contained in:
takatost
2023-08-12 00:57:00 +08:00
committed by GitHub
parent d8b712b325
commit 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@@ -0,0 +1,44 @@
from typing import Type
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.providers.base import BaseModelProvider
class FakeModelProvider(BaseModelProvider):
@property
def provider_name(self):
return 'fake'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [{'id': 'test_model', 'name': 'Test Model'}]
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
return OpenAIModel
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
pass
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
pass
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
return credentials
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
return ModelKwargsRules()
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
return {}

View File

@@ -0,0 +1,123 @@
from typing import List, Optional, Any
import anthropic
import httpx
import pytest
from unittest.mock import patch
import json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.schema import BaseMessage, ChatResult, ChatGeneration, AIMessage
from core.model_providers.providers.anthropic_provider import AnthropicProvider
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'anthropic'
MODEL_PROVIDER_CLASS = AnthropicProvider
VALIDATE_CREDENTIAL_KEY = 'anthropic_api_key'
def mock_chat_generate(messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content='answer'))])
def mock_chat_generate_invalid(messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any):
raise anthropic.APIStatusError('Invalid credentials',
request=httpx._models.Request(
method='POST',
url='https://api.anthropic.com/v1/completions',
),
response=httpx._models.Response(
status_code=401,
),
body=None
)
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
@patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate)
def test_is_provider_credentials_valid_or_raise_valid(mock_create):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'valid_key'})
@patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate_invalid)
def test_is_provider_credentials_valid_or_raise_invalid(mock_create):
# raise CredentialsValidateFailedError if anthropic_api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
# raise CredentialsValidateFailedError if anthropic_api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'})
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key})
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
api_key = 'valid_key'
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{api_key}'}),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2]
assert len(middle_token) == max(len(api_key) - 8, 0)
assert all(char == '*' for char in middle_token)
@patch('core.model_providers.providers.hosted.hosted_model_providers.anthropic')
def test_get_credentials_hosted(mock_hosted):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.SYSTEM.value,
encrypted_config='',
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
mock_hosted.api_key = 'hosted_key'
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key'

View File

@@ -0,0 +1,117 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'azure_openai'
MODEL_PROVIDER_CLASS = AzureOpenAIProvider
VALIDATE_CREDENTIAL = {
'openai_api_base': 'https://xxxx.openai.azure.com/',
'openai_api_key': 'valid_key',
'base_model_name': 'gpt-35-turbo'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_model_credentials_valid_or_raise(mocker):
mocker.patch('langchain.chat_models.base.BaseChatModel.generate', return_value=None)
# assert True if credentials is valid
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL
)
def test_is_model_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if credentials is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={}
)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
openai_api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={'openai_api_key': openai_api_key}
)
mock_encrypt.assert_called_with('tenant_id', openai_api_key)
assert result['openai_api_key'] == f'encrypted_{openai_api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION
)
assert result['openai_api_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
obfuscated=True
)
middle_token = result['openai_api_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['openai_api_key']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,72 @@
from unittest.mock import MagicMock
import pytest
from core.model_providers.error import QuotaExceededError
from core.model_providers.models.entity.model_params import ModelType
from models.provider import Provider, ProviderType
from tests.unit_tests.model_providers.fake_model_provider import FakeModelProvider
def test_get_supported_model_list(mocker):
mocker.patch.object(
FakeModelProvider,
'get_rules',
return_value={'support_provider_types': ['custom'], 'model_flexibility': 'configurable'}
)
mock_provider_model = MagicMock()
mock_provider_model.model_name = 'test_model'
mock_query = MagicMock()
mock_query.filter.return_value.order_by.return_value.all.return_value = [mock_provider_model]
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
provider = FakeModelProvider(provider=Provider())
result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
assert result == [{'id': 'test_model', 'name': 'test_model'}]
def test_check_quota_over_limit(mocker):
mocker.patch.object(
FakeModelProvider,
'get_rules',
return_value={'support_provider_types': ['system']}
)
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = None
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value))
with pytest.raises(QuotaExceededError):
provider.check_quota_over_limit()
def test_check_quota_not_over_limit(mocker):
mocker.patch.object(
FakeModelProvider,
'get_rules',
return_value={'support_provider_types': ['system']}
)
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = Provider()
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value))
assert provider.check_quota_over_limit() is None
def test_check_custom_quota_over_limit(mocker):
mocker.patch.object(
FakeModelProvider,
'get_rules',
return_value={'support_provider_types': ['custom']}
)
provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.CUSTOM.value))
assert provider.check_quota_over_limit() is None

View File

@@ -0,0 +1,89 @@
import pytest
from unittest.mock import patch
import json
from langchain.schema import LLMResult, Generation, AIMessage, ChatResult, ChatGeneration
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
from core.model_providers.providers.spark_provider import SparkProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'chatglm'
MODEL_PROVIDER_CLASS = ChatGLMProvider
VALIDATE_CREDENTIAL = {
'api_base': 'valid_api_base',
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.chatglm.ChatGLM._call',
return_value="abc")
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_base'] = 'invalid_api_base'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_base'] == f'encrypted_{VALIDATE_CREDENTIAL["api_base"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_base'] == 'valid_api_base'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_base'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_base']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,161 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'huggingface_hub'
MODEL_PROVIDER_CLASS = HuggingfaceHubProvider
HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'valid_key'
}
INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'valid_key',
'huggingfacehub_endpoint_url': 'valid_url'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
@patch('huggingface_hub.hf_api.ModelInfo')
def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL
)
@patch('huggingface_hub.hf_api.ModelInfo')
def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_info):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={}
)
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
})
def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
)
def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={}
)
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_endpoint_url': 'valid_url'
})
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy()
)
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['huggingfacehub_api_token'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy()
encrypted_credential['huggingfacehub_api_token'] = 'encrypted_' + encrypted_credential['huggingfacehub_api_token']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION
)
assert result['huggingfacehub_api_token'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy()
encrypted_credential['huggingfacehub_api_token'] = 'encrypted_' + encrypted_credential['huggingfacehub_api_token']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
obfuscated=True
)
middle_token = result['huggingfacehub_api_token'][6:-2]
assert len(middle_token) == max(len(INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL['huggingfacehub_api_token']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,88 @@
import pytest
from unittest.mock import patch
import json
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.minimax_provider import MinimaxProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'minimax'
MODEL_PROVIDER_CLASS = MinimaxProvider
VALIDATE_CREDENTIAL = {
'minimax_group_id': 'fake-group-id',
'minimax_api_key': 'valid_key'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.minimax.Minimax._call', return_value='abc')
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['minimax_api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['minimax_api_key'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['minimax_api_key'] = 'encrypted_' + encrypted_credential['minimax_api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['minimax_api_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['minimax_api_key'] = 'encrypted_' + encrypted_credential['minimax_api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['minimax_api_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['minimax_api_key']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,126 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from openai.error import AuthenticationError
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'openai'
MODEL_PROVIDER_CLASS = OpenAIProvider
VALIDATE_CREDENTIAL_KEY = 'openai_api_key'
def moderation_side_effect(*args, **kwargs):
if kwargs['api_key'] == 'valid_key':
mock_instance = MagicMock()
mock_instance.request = MagicMock()
return mock_instance, {}
else:
raise AuthenticationError('Invalid credentials')
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
@patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
def test_is_provider_credentials_valid_or_raise_valid(mock_create):
# assert True if api_key is valid
credentials = {VALIDATE_CREDENTIAL_KEY: 'valid_key'}
assert MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credentials) is None
@patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
def test_is_provider_credentials_valid_or_raise_invalid(mock_create):
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'})
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key})
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom_str(mock_decrypt):
"""
Only the OpenAI provider needs to be compatible with the previous case where the encrypted_config was stored as a plain string.
:param mock_decrypt:
:return:
"""
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config='encrypted_valid_key',
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
openai_api_key = 'valid_key'
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{openai_api_key}'}),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2]
assert len(middle_token) == max(len(openai_api_key) - 8, 0)
assert all(char == '*' for char in middle_token)
@patch('core.model_providers.providers.hosted.hosted_model_providers.openai')
def test_get_credentials_hosted(mock_hosted):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.SYSTEM.value,
encrypted_config='',
is_valid=True
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
mock_hosted.api_key = 'hosted_key'
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key'

View File

@@ -0,0 +1,125 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.replicate_provider import ReplicateProvider
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'replicate'
MODEL_PROVIDER_CLASS = ReplicateProvider
VALIDATE_CREDENTIAL = {
'model_version': 'fake-version',
'replicate_api_token': 'valid_key'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_credentials_valid_or_raise_valid(mocker):
mock_query = MagicMock()
mock_query.return_value = None
mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
mocker.patch('replicate.model.Model.versions', return_value=mock_query)
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL.copy()
)
def test_is_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if replicate_api_token is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={}
)
# raise CredentialsValidateFailedError if replicate_api_token is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={'replicate_api_token': 'invalid_key'})
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL.copy()
)
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['replicate_api_token'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION
)
assert result['replicate_api_token'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
obfuscated=True
)
middle_token = result['replicate_api_token'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['replicate_api_token']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,97 @@
import pytest
from unittest.mock import patch
import json
from langchain.schema import LLMResult, Generation, AIMessage, ChatResult, ChatGeneration
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.spark_provider import SparkProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'spark'
MODEL_PROVIDER_CLASS = SparkProvider
VALIDATE_CREDENTIAL = {
'app_id': 'valid_app_id',
'api_key': 'valid_key',
'api_secret': 'valid_secret'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.spark.ChatSpark._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content="abc"))]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
assert result['api_secret'] == f'encrypted_{VALIDATE_CREDENTIAL["api_secret"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_key'] == 'valid_key'
assert result['api_secret'] == 'valid_secret'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_key'][6:-2]
middle_secret = result['api_secret'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
assert len(middle_secret) == max(len(VALIDATE_CREDENTIAL['api_secret']) - 8, 0)
assert all(char == '*' for char in middle_token)
assert all(char == '*' for char in middle_secret)

View File

@@ -0,0 +1,90 @@
import pytest
from unittest.mock import patch
import json
from langchain.schema import LLMResult, Generation
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.minimax_provider import MinimaxProvider
from core.model_providers.providers.tongyi_provider import TongyiProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'tongyi'
MODEL_PROVIDER_CLASS = TongyiProvider
VALIDATE_CREDENTIAL = {
'dashscope_api_key': 'valid_key'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['dashscope_api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['dashscope_api_key'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['dashscope_api_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['dashscope_api_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['dashscope_api_key']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,93 @@
import pytest
from unittest.mock import patch
import json
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.wenxin_provider import WenxinProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'wenxin'
MODEL_PROVIDER_CLASS = WenxinProvider
VALIDATE_CREDENTIAL = {
'api_key': 'valid_key',
'secret_key': 'valid_secret'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_key'] == 'valid_key'
assert result['secret_key'] == 'valid_secret'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_key'][6:-2]
middle_secret = result['secret_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
assert len(middle_secret) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
assert all(char == '*' for char in middle_token)
assert all(char == '*' for char in middle_secret)