mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 15:29:21 +08:00
feat: server multi models support (#799)
This commit is contained in:
0
api/core/model_providers/providers/__init__.py
Normal file
0
api/core/model_providers/providers/__init__.py
Normal file
224
api/core/model_providers/providers/anthropic_provider.py
Normal file
224
api/core/model_providers/providers/anthropic_provider.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Type, Optional
|
||||
|
||||
import anthropic
|
||||
from flask import current_app
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
|
||||
from core.model_providers.models.entity.provider import ModelFeature
|
||||
from core.model_providers.models.llm.anthropic_model import AnthropicModel
|
||||
from core.model_providers.models.llm.base import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.model_providers.providers.hosted import hosted_model_providers
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class AnthropicProvider(BaseModelProvider):
    """Model provider implementation for Anthropic (Claude) models.

    Supplies the fixed model catalogue, generation parameter rules and
    credential validation/encryption/retrieval for the 'anthropic'
    provider, supporting both tenant-supplied (custom) credentials and
    hosted (system) credentials.
    """

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'anthropic'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the static list of Anthropic models for the given type.

        Only text generation models are offered; any other model type
        yields an empty list.
        """
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'claude-instant-1',
                    'name': 'claude-instant-1',
                },
                {
                    'id': 'claude-2',
                    'name': 'claude-2',
                    # claude-2 additionally advertises agent-thought support.
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for any type other than TEXT_GENERATION.
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = AnthropicModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=1),
            top_p=KwargRule[float](min=0, max=1, default=0.7),
            # Anthropic's API has no penalty parameters, so both are disabled.
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            # Generic 'max_tokens' maps onto Anthropic's 'max_tokens_to_sample'.
            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.

        Sends a minimal 'ping' completion via ChatAnthropic with the
        supplied key (and optional API URL) and converts API-level
        failures into CredentialsValidateFailedError; anything else is
        logged and re-raised.
        """
        if 'anthropic_api_key' not in credentials:
            raise CredentialsValidateFailedError('Anthropic API Key must be provided.')

        try:
            credential_kwargs = {
                'anthropic_api_key': credentials['anthropic_api_key']
            }

            # A custom API endpoint is optional.
            if 'anthropic_api_url' in credentials:
                credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url']

            chat_llm = ChatAnthropic(
                model='claude-instant-1',
                max_tokens_to_sample=10,
                temperature=0,
                default_request_timeout=60,
                **credential_kwargs
            )

            messages = [
                HumanMessage(
                    content="ping"
                )
            ]

            chat_llm(messages)
        except anthropic.APIConnectionError as ex:
            raise CredentialsValidateFailedError(str(ex))
        except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
            raise CredentialsValidateFailedError(str(ex))
        except Exception as ex:
            # Unexpected failures are logged with traceback and re-raised as-is.
            logging.exception('Anthropic config validation failed')
            raise ex

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """Encrypt the API key in-place for persistence and return the dict."""
        credentials['anthropic_api_key'] = encrypter.encrypt_token(tenant_id, credentials['anthropic_api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """Return provider credentials for LLM use.

        Custom providers decrypt (and optionally obfuscate for display)
        the stored config; system providers return the hosted
        credentials, or Nones when hosting is not configured.

        :param obfuscated: mask the API key for safe display when True.
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                # Corrupt/empty stored config degrades to empty credentials.
                credentials = {
                    'anthropic_api_url': None,
                    'anthropic_api_key': None
                }

            if credentials['anthropic_api_key']:
                credentials['anthropic_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['anthropic_api_key']
                )

                if obfuscated:
                    credentials['anthropic_api_key'] = encrypter.obfuscated_token(credentials['anthropic_api_key'])

            # Older stored configs may predate the api_url field.
            if 'anthropic_api_url' not in credentials:
                credentials['anthropic_api_url'] = None

            return credentials
        else:
            if hosted_model_providers.anthropic:
                return {
                    'anthropic_api_url': hosted_model_providers.anthropic.api_base,
                    'anthropic_api_key': hosted_model_providers.anthropic.api_key,
                }
            else:
                return {
                    'anthropic_api_url': None,
                    'anthropic_api_key': None
                }

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        """System (hosted) mode is only available in the CLOUD edition and
        only when hosted Anthropic credentials are configured."""
        if current_app.config['EDITION'] != 'CLOUD':
            return False

        if hosted_model_providers.anthropic:
            return True

        return False

    def should_deduct_quota(self):
        """Deduct quota only when a positive hosted quota limit is set."""
        if hosted_model_providers.anthropic and \
                hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > 0:
            return True

        return False

    def get_payment_info(self) -> Optional[dict]:
        """
        get product info if it payable.

        :return: Stripe product info dict when paid hosting is enabled,
            otherwise None.
        """
        if hosted_model_providers.anthropic \
                and hosted_model_providers.anthropic.paid_enabled:
            return {
                'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
                'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
            }

        return None

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        # Per-model credentials are not used; provider-level validation suffices.
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return: empty dict — this provider has no per-model credentials.
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        # All Anthropic models share the provider-level credentials.
        return self.get_provider_credentials(obfuscated)
|
||||
387
api/core/model_providers/providers/azure_openai_provider.py
Normal file
387
api/core/model_providers/providers/azure_openai_provider.py
Normal file
@@ -0,0 +1,387 @@
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
import openai
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
|
||||
AZURE_OPENAI_API_VERSION
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
|
||||
from core.model_providers.models.entity.provider import ModelFeature
|
||||
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.model_providers.providers.hosted import hosted_model_providers
|
||||
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
|
||||
from extensions.ext_database import db
|
||||
from models.provider import ProviderType, ProviderModel, ProviderQuotaType
|
||||
|
||||
# Azure OpenAI base (foundation) model names that a custom deployment may be
# backed by; 'base_model_name' in model credentials must be one of these.
BASE_MODELS = [
    'gpt-4',
    'gpt-4-32k',
    'gpt-35-turbo',
    'gpt-35-turbo-16k',
    'text-davinci-003',
    'text-embedding-ada-002',
]
|
||||
|
||||
|
||||
class AzureOpenAIProvider(BaseModelProvider):
    """Model provider implementation for Azure OpenAI deployments.

    Unlike fixed-catalogue providers, custom Azure models are configured
    per-deployment as ProviderModel rows; legacy provider-level configs
    are migrated into per-model rows lazily on access.
    """

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'azure_openai'

    def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the usable model list for the given type.

        Custom providers list their valid configured deployments from
        the database; system providers fall back to the fixed list.
        """
        # convert old provider config to provider models
        self._convert_provider_config_to_model_config()

        if self.provider.provider_type == ProviderType.CUSTOM.value:
            # get configurable provider models
            provider_models = db.session.query(ProviderModel).filter(
                ProviderModel.tenant_id == self.provider.tenant_id,
                ProviderModel.provider_name == self.provider.provider_name,
                ProviderModel.model_type == model_type.value,
                ProviderModel.is_valid == True
            ).order_by(ProviderModel.created_at.asc()).all()

            model_list = []
            for provider_model in provider_models:
                model_dict = {
                    'id': provider_model.model_name,
                    'name': provider_model.model_name
                }

                # Chat-capable base models advertise agent-thought support.
                credentials = json.loads(provider_model.encrypted_config)
                if credentials['base_model_name'] in [
                    'gpt-4',
                    'gpt-4-32k',
                    'gpt-35-turbo',
                    'gpt-35-turbo-16k',
                ]:
                    model_dict['features'] = [
                        ModelFeature.AGENT_THOUGHT.value
                    ]

                model_list.append(model_dict)
        else:
            model_list = self._get_fixed_model_list(model_type)

        return model_list

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the static model list used for system (hosted) providers."""
        if model_type == ModelType.TEXT_GENERATION:
            models = [
                {
                    'id': 'gpt-3.5-turbo',
                    'name': 'gpt-3.5-turbo',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-3.5-turbo-16k',
                    'name': 'gpt-3.5-turbo-16k',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-4',
                    'name': 'gpt-4',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-4-32k',
                    'name': 'gpt-4-32k',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'text-davinci-003',
                    'name': 'text-davinci-003',
                }
            ]

            # Trial-quota tenants are not offered GPT-4 class models.
            if self.provider.provider_type == ProviderType.SYSTEM.value \
                    and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
                models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]

            return models
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'text-embedding-ada-002',
                    'name': 'text-embedding-ada-002'
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for unsupported model types.
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = AzureOpenAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = AzureOpenAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        The max_tokens ceiling depends on the deployment's base model;
        unknown base models fall back to 4097.

        :param model_name:
        :param model_type:
        :return:
        """
        base_model_max_tokens = {
            'gpt-4': 8192,
            'gpt-4-32k': 32768,
            'gpt-35-turbo': 4096,
            'gpt-35-turbo-16k': 16384,
            'text-davinci-003': 4097,
        }

        # Credentials are looked up to learn which base model backs this deployment.
        model_credentials = self.get_model_credentials(model_name, model_type)

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=1),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
            max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
                model_credentials['base_model_name'],
                4097
            ), default=16),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        Verifies required keys, then performs a live probe against the
        deployment (a 1-message chat completion or a 1-query embedding).

        :param model_name:
        :param model_type:
        :param credentials:
        :raises CredentialsValidateFailedError: on missing/invalid credentials
            or when the deployment is unreachable.
        """
        if 'openai_api_key' not in credentials:
            raise CredentialsValidateFailedError('Azure OpenAI API key is required')

        if 'openai_api_base' not in credentials:
            raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')

        if 'base_model_name' not in credentials:
            raise CredentialsValidateFailedError('Base Model Name is required')

        if credentials['base_model_name'] not in BASE_MODELS:
            raise CredentialsValidateFailedError('Base Model Name is invalid')

        if model_type == ModelType.TEXT_GENERATION:
            try:
                # NOTE(review): chat validation hardcodes '2023-07-01-preview'
                # while the embeddings branch uses AZURE_OPENAI_API_VERSION —
                # confirm whether these should be the same constant.
                client = EnhanceAzureChatOpenAI(
                    deployment_name=model_name,
                    temperature=0,
                    max_tokens=15,
                    request_timeout=10,
                    openai_api_type='azure',
                    openai_api_version='2023-07-01-preview',
                    openai_api_key=credentials['openai_api_key'],
                    openai_api_base=credentials['openai_api_base'],
                )

                client.generate([[HumanMessage(content='hi!')]])
            except openai.error.OpenAIError as e:
                raise CredentialsValidateFailedError(
                    f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}")
            except Exception as e:
                logging.exception("Azure OpenAI Model retrieve failed.")
                raise e
        elif model_type == ModelType.EMBEDDINGS:
            try:
                client = OpenAIEmbeddings(
                    openai_api_type='azure',
                    openai_api_version=AZURE_OPENAI_API_VERSION,
                    deployment=model_name,
                    chunk_size=16,
                    max_retries=1,
                    openai_api_key=credentials['openai_api_key'],
                    openai_api_base=credentials['openai_api_base']
                )

                client.embed_query('hi')
            except openai.error.OpenAIError as e:
                logging.exception("Azure OpenAI Model check error.")
                raise CredentialsValidateFailedError(
                    f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}")
            except Exception as e:
                logging.exception("Azure OpenAI Model retrieve failed.")
                raise e

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        Custom providers read the per-deployment ProviderModel config
        (decrypting and optionally obfuscating the key); system
        providers return hosted credentials with base_model_name set to
        the model name itself.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            # convert old provider config to provider models
            self._convert_provider_config_to_model_config()

            provider_model = self._get_provider_model(model_name, model_type)

            if not provider_model.encrypted_config:
                return {
                    'openai_api_base': '',
                    'openai_api_key': '',
                    'base_model_name': ''
                }

            credentials = json.loads(provider_model.encrypted_config)
            if credentials['openai_api_key']:
                credentials['openai_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['openai_api_key']
                )

                if obfuscated:
                    credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])

            return credentials
        else:
            if hosted_model_providers.azure_openai:
                return {
                    'openai_api_base': hosted_model_providers.azure_openai.api_base,
                    'openai_api_key': hosted_model_providers.azure_openai.api_key,
                    'base_model_name': model_name
                }
            else:
                return {
                    'openai_api_base': None,
                    'openai_api_key': None,
                    'base_model_name': None
                }

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        """System (hosted) mode requires the CLOUD edition and configured
        hosted Azure OpenAI credentials."""
        if current_app.config['EDITION'] != 'CLOUD':
            return False

        if hosted_model_providers.azure_openai:
            return True

        return False

    def should_deduct_quota(self):
        """Deduct quota only when a positive hosted quota limit is set."""
        if hosted_model_providers.azure_openai \
                and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > 0:
            return True

        return False

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        # Credentials are validated per-model for Azure, not per-provider.
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        # No provider-level credentials; everything lives on the model rows.
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        # No provider-level credentials; everything lives on the model rows.
        return {}

    def _convert_provider_config_to_model_config(self):
        """Migrate a legacy provider-level config into per-model rows.

        For a valid custom provider that still carries an encrypted
        provider config, creates ProviderModel rows for the default
        deployments and clears the legacy config.
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                and self.provider.is_valid \
                and self.provider.encrypted_config:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                # Corrupt legacy config migrates as empty credentials.
                credentials = {
                    'openai_api_base': '',
                    'openai_api_key': '',
                    'base_model_name': ''
                }

            self._add_provider_model(
                model_name='gpt-35-turbo',
                model_type=ModelType.TEXT_GENERATION,
                provider_credentials=credentials
            )

            self._add_provider_model(
                model_name='gpt-35-turbo-16k',
                model_type=ModelType.TEXT_GENERATION,
                provider_credentials=credentials
            )

            self._add_provider_model(
                model_name='gpt-4',
                model_type=ModelType.TEXT_GENERATION,
                provider_credentials=credentials
            )

            self._add_provider_model(
                model_name='text-davinci-003',
                model_type=ModelType.TEXT_GENERATION,
                provider_credentials=credentials
            )

            self._add_provider_model(
                model_name='text-embedding-ada-002',
                model_type=ModelType.EMBEDDINGS,
                provider_credentials=credentials
            )

            # Clearing the legacy config marks the migration as done.
            self.provider.encrypted_config = None
            db.session.commit()

    def _add_provider_model(self, model_name: str, model_type: ModelType, provider_credentials: dict):
        """Persist one ProviderModel row using a copy of the legacy
        credentials with base_model_name set to the deployment name."""
        credentials = provider_credentials.copy()
        credentials['base_model_name'] = model_name
        provider_model = ProviderModel(
            tenant_id=self.provider.tenant_id,
            provider_name=self.provider.provider_name,
            model_name=model_name,
            model_type=model_type.value,
            encrypted_config=json.dumps(credentials),
            is_valid=True
        )
        db.session.add(provider_model)
        db.session.commit()
|
||||
283
api/core/model_providers/providers/base.py
Normal file
283
api/core/model_providers/providers/base.py
Normal file
@@ -0,0 +1,283 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Type, Optional
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_providers.error import QuotaExceededError, LLMBadRequestError
|
||||
from extensions.ext_database import db
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
|
||||
from core.model_providers.models.entity.provider import ProviderQuotaUnit
|
||||
from core.model_providers.rules import provider_rules
|
||||
from models.provider import Provider, ProviderType, ProviderModel
|
||||
|
||||
|
||||
class BaseModelProvider(BaseModel, ABC):
    """Abstract base for all model providers.

    Concrete subclasses implement the model catalogue, parameter rules
    and credential handling; this base supplies rule lookup, the
    default supported-model-list logic, quota bookkeeping and shared
    DB helpers.
    """

    # The Provider DB record this instance wraps (tenant, type, quota, config).
    provider: Provider

    class Config:
        """Configuration for this pydantic object."""

        # Provider is a SQLAlchemy model, not a pydantic type.
        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        raise NotImplementedError

    def get_rules(self):
        """
        Returns the rules of a provider.
        """
        return provider_rules[self.provider_name]

    def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get supported model object list for use.

        Providers with a fixed catalogue (or no custom support) return
        the fixed list; configurable providers read their valid models
        from the database.

        :param model_type:
        :return:
        """
        rules = self.get_rules()
        if 'custom' not in rules['support_provider_types']:
            return self._get_fixed_model_list(model_type)

        if 'model_flexibility' not in rules:
            return self._get_fixed_model_list(model_type)

        if rules['model_flexibility'] == 'fixed':
            return self._get_fixed_model_list(model_type)

        # get configurable provider models
        provider_models = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True
        ).order_by(ProviderModel.created_at.asc()).all()

        return [{
            'id': provider_model.model_name,
            'name': provider_model.model_name
        } for provider_model in provider_models]

    @abstractmethod
    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get supported model object list for use.

        :param model_type:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_class(self, model_type: ModelType) -> Type:
        """
        get specific model class.

        :param model_type:
        :return:
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        check provider credentials valid.

        :param credentials:
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        encrypt provider credentials for save.

        :param tenant_id:
        :param credentials:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param obfuscated:
        :return:
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        raise NotImplementedError

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        # By default the hosted (system) provider type only exists in CLOUD.
        return current_app.config['EDITION'] == 'CLOUD'

    def check_quota_over_limit(self):
        """
        check provider quota over limit.

        Only applies to system providers whose rules support the
        system type; raises QuotaExceededError when no valid provider
        row with remaining quota exists.

        :return:
        :raises QuotaExceededError: when the quota is exhausted.
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        provider = db.session.query(Provider).filter(
            db.and_(
                Provider.id == self.provider.id,
                Provider.is_valid == True,
                Provider.quota_limit > Provider.quota_used
            )
        ).first()

        if not provider:
            raise QuotaExceededError()

    def deduct_quota(self, used_tokens: int = 0) -> None:
        """
        deduct available quota when provider type is system or paid.

        The deducted amount is either token count or a single "time",
        depending on the provider rules' quota_unit.

        :return:
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        if not self.should_deduct_quota():
            return

        # Default to per-call (TIMES) accounting when rules don't specify.
        if 'system_config' not in rules:
            quota_unit = ProviderQuotaUnit.TIMES.value
        elif 'quota_unit' not in rules['system_config']:
            quota_unit = ProviderQuotaUnit.TIMES.value
        else:
            quota_unit = rules['system_config']['quota_unit']

        if quota_unit == ProviderQuotaUnit.TOKENS.value:
            used_quota = used_tokens
        else:
            used_quota = 1

        # Guard (quota_limit > quota_used) prevents deducting past the limit.
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name,
            Provider.provider_type == self.provider.provider_type,
            Provider.quota_type == self.provider.quota_type,
            Provider.quota_limit > Provider.quota_used
        ).update({'quota_used': Provider.quota_used + used_quota})
        db.session.commit()

    def should_deduct_quota(self):
        # Subclasses opt in to quota deduction; default is no deduction.
        return False

    def update_last_used(self) -> None:
        """
        update last used time.

        :return:
        """
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name
        ).update({'last_used': datetime.utcnow()})
        db.session.commit()

    def get_payment_info(self) -> Optional[dict]:
        """
        get product info if it payable.

        :return: None by default; payable providers override this.
        """
        return None

    def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
        """
        get provider model.

        :param model_name:
        :param model_type:
        :return:
        :raises LLMBadRequestError: when no valid matching row exists.
        """
        provider_model = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_name == model_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True
        ).first()

        if not provider_model:
            raise LLMBadRequestError(f"The model {model_name} does not exist. "
                                     f"Please check the configuration.")

        return provider_model
|
||||
|
||||
|
||||
class CredentialsValidateFailedError(Exception):
    """Raised when provider or model credentials fail validation."""
|
||||
157
api/core/model_providers/providers/chatglm_provider.py
Normal file
157
api/core/model_providers/providers/chatglm_provider.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from langchain.llms import ChatGLM
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class ChatGLMProvider(BaseModelProvider):
|
||||
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'chatglm'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
return [
|
||||
{
|
||||
'id': 'chatglm2-6b',
|
||||
'name': 'ChatGLM2-6B',
|
||||
},
|
||||
{
|
||||
'id': 'chatglm-6b',
|
||||
'name': 'ChatGLM-6B',
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = ChatGLMModel
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
model_max_tokens = {
|
||||
'chatglm-6b': 2000,
|
||||
'chatglm2-6b': 32000,
|
||||
}
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
"""
|
||||
if 'api_base' not in credentials:
|
||||
raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'endpoint_url': credentials['api_base']
|
||||
}
|
||||
|
||||
llm = ChatGLM(
|
||||
max_token=10,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        Encrypt provider credentials before persisting.

        :param tenant_id: tenant that scopes the encryption key.
        :param credentials: mutated in place; 'api_base' is replaced by its ciphertext.
        :return: the same dict with the endpoint URL encrypted.
        """
        credentials['api_base'] = encrypter.encrypt_token(tenant_id, credentials['api_base'])
        return credentials
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
if self.provider.provider_type == ProviderType.CUSTOM.value:
|
||||
try:
|
||||
credentials = json.loads(self.provider.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
credentials = {
|
||||
'api_base': None
|
||||
}
|
||||
|
||||
if credentials['api_base']:
|
||||
credentials['api_base'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['api_base']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['api_base'] = encrypter.obfuscated_token(credentials['api_base'])
|
||||
|
||||
return credentials
|
||||
|
||||
return {}
|
||||
|
||||
    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        # Models share the provider-level endpoint, so there is nothing
        # model-specific to validate here.
        return
|
||||
|
||||
    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return: empty dict — this provider has no per-model credentials.
        """
        return {}
|
||||
|
||||
    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return: the provider-level credentials — all models share them.
        """
        return self.get_provider_credentials(obfuscated)
|
||||
76
api/core/model_providers/providers/hosted.py
Normal file
76
api/core/model_providers/providers/hosted.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import langchain
|
||||
from flask import Flask
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HostedOpenAI(BaseModel):
    """Connection and billing settings for the platform-hosted OpenAI provider."""
    api_base: str = None  # optional custom endpoint; None means the official host
    api_organization: str = None
    api_key: str  # required: the shared platform API key
    quota_limit: int = 0
    """Quota limit for the openai hosted model. 0 means unlimited."""
    paid_enabled: bool = False  # whether users may purchase extra quota
    paid_stripe_price_id: str = None  # Stripe price used for the purchase flow
    paid_increase_quota: int = 1  # quota granted per successful payment
|
||||
|
||||
|
||||
class HostedAzureOpenAI(BaseModel):
    """Connection settings for the platform-hosted Azure OpenAI provider."""
    api_base: str  # required: Azure resource endpoint
    api_key: str  # required: the shared platform API key
    quota_limit: int = 0
    """Quota limit for the azure openai hosted model. 0 means unlimited."""
|
||||
|
||||
|
||||
class HostedAnthropic(BaseModel):
    """Connection and billing settings for the platform-hosted Anthropic provider."""
    api_base: str = None  # optional custom endpoint; None means the official host
    api_key: str  # required: the shared platform API key
    quota_limit: int = 0
    """Quota limit for the anthropic hosted model. 0 means unlimited."""
    paid_enabled: bool = False  # whether users may purchase extra quota
    paid_stripe_price_id: str = None  # Stripe price used for the purchase flow
    paid_increase_quota: int = 1  # quota granted per successful payment
|
||||
|
||||
|
||||
class HostedModelProviders(BaseModel):
    """Registry of hosted providers enabled for this deployment (None = disabled)."""
    openai: Optional[HostedOpenAI] = None
    azure_openai: Optional[HostedAzureOpenAI] = None
    anthropic: Optional[HostedAnthropic] = None


# Module-level singleton, populated by init_app() at application startup.
hosted_model_providers = HostedModelProviders()
|
||||
|
||||
|
||||
def init_app(app: Flask):
    """Populate the hosted-provider registry from the Flask app config.

    Each provider entry is built only when its *_ENABLED flag is set;
    otherwise the corresponding slot in the registry stays None.
    """
    debug_env = os.environ.get("DEBUG")
    if debug_env and debug_env.lower() == 'true':
        langchain.verbose = True

    config = app.config

    if config.get("HOSTED_OPENAI_ENABLED"):
        hosted_model_providers.openai = HostedOpenAI(
            api_base=config.get("HOSTED_OPENAI_API_BASE"),
            api_organization=config.get("HOSTED_OPENAI_API_ORGANIZATION"),
            api_key=config.get("HOSTED_OPENAI_API_KEY"),
            quota_limit=config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
            paid_enabled=config.get("HOSTED_OPENAI_PAID_ENABLED"),
            paid_stripe_price_id=config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
            paid_increase_quota=config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
        )

    if config.get("HOSTED_AZURE_OPENAI_ENABLED"):
        hosted_model_providers.azure_openai = HostedAzureOpenAI(
            api_base=config.get("HOSTED_AZURE_OPENAI_API_BASE"),
            api_key=config.get("HOSTED_AZURE_OPENAI_API_KEY"),
            quota_limit=config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT"),
        )

    if config.get("HOSTED_ANTHROPIC_ENABLED"):
        hosted_model_providers.anthropic = HostedAnthropic(
            api_base=config.get("HOSTED_ANTHROPIC_API_BASE"),
            api_key=config.get("HOSTED_ANTHROPIC_API_KEY"),
            quota_limit=config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
            paid_enabled=config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
            paid_stripe_price_id=config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
            paid_increase_quota=config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
        )
|
||||
183
api/core/model_providers/providers/huggingface_hub_provider.py
Normal file
183
api/core/model_providers/providers/huggingface_hub_provider.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from langchain.llms import HuggingFaceEndpoint
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class HuggingfaceHubProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'huggingface_hub'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        # Models are user-specified Hub repos, so there is no fixed catalog.
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: only text generation is supported.
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = HuggingfaceHubModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=1500, default=200),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        :raises CredentialsValidateFailedError: on any invalid field or failed
            round-trip to the Hugging Face Hub / endpoint.
        """
        if model_type != ModelType.TEXT_GENERATION:
            raise NotImplementedError

        if 'huggingfacehub_api_type' not in credentials \
                or credentials['huggingfacehub_api_type'] not in ['hosted_inference_api', 'inference_endpoints']:
            raise CredentialsValidateFailedError('Hugging Face Hub API Type invalid, '
                                                 'must be hosted_inference_api or inference_endpoints.')

        if 'huggingfacehub_api_token' not in credentials:
            raise CredentialsValidateFailedError('Hugging Face Hub API Token must be provided.')

        hfapi = HfApi(token=credentials['huggingfacehub_api_token'])

        # whoami() is the cheapest authenticated call — proves the token works.
        try:
            hfapi.whoami()
        except Exception:
            raise CredentialsValidateFailedError("Invalid API Token.")

        if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
            if 'huggingfacehub_endpoint_url' not in credentials:
                raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')

            try:
                llm = HuggingFaceEndpoint(
                    endpoint_url=credentials['huggingfacehub_endpoint_url'],
                    task="text2text-generation",
                    model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
                    huggingfacehub_api_token=credentials['huggingfacehub_api_token']
                )

                llm("ping")
            except Exception as e:
                raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
        else:
            try:
                model_info = hfapi.model_info(repo_id=model_name)
                if not model_info:
                    raise ValueError(f'Model {model_name} not found.')

                if 'inference' in model_info.cardData and not model_info.cardData['inference']:
                    raise ValueError(f'Inference API has been turned off for this model {model_name}.')

                VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
                if model_info.pipeline_tag not in VALID_TASKS:
                    raise ValueError(f"Model {model_name} is not a valid task, "
                                     f"must be one of {VALID_TASKS}.")
            except Exception as e:
                raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        # Keep the plaintext token for the Hub lookup below. The previous
        # implementation encrypted the token first and then handed the
        # ciphertext to HfApi, so the model_info() call could never
        # authenticate; the encryption must happen last.
        api_token = credentials['huggingfacehub_api_token']

        if credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
            hfapi = HfApi(token=api_token)
            model_info = hfapi.model_info(repo_id=model_name)
            if not model_info:
                raise ValueError(f'Model {model_name} not found.')

            if 'inference' in model_info.cardData and not model_info.cardData['inference']:
                raise ValueError(f'Inference API has been turned off for this model {model_name}.')

            credentials['task_type'] = model_info.pipeline_tag

        credentials['huggingfacehub_api_token'] = encrypter.encrypt_token(tenant_id, api_token)

        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated: mask the token for safe display when True.
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'huggingfacehub_api_token': None,
                'task_type': None
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['huggingfacehub_api_token']:
            credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['huggingfacehub_api_token']
            )

            if obfuscated:
                credentials['huggingfacehub_api_token'] = encrypter.obfuscated_token(credentials['huggingfacehub_api_token'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        # Credentials are validated per model, not per provider.
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        # No provider-level credentials for this provider.
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        # No provider-level credentials for this provider.
        return {}
|
||||
179
api/core/model_providers/providers/minimax_provider.py
Normal file
179
api/core/model_providers/providers/minimax_provider.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from langchain.llms import Minimax
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.llm.minimax_model import MinimaxModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
|
||||
class MinimaxProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'minimax'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'abab5.5-chat',
                    'name': 'abab5.5-chat',
                },
                {
                    'id': 'abab5-chat',
                    'name': 'abab5-chat',
                }
            ]
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'embo-01',
                    'name': 'embo-01',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for unsupported model types.
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = MinimaxModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = MinimaxEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        model_max_tokens = {
            'abab5.5-chat': 16384,
            'abab5-chat': 6144,
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0.01, max=1, default=0.9),
            top_p=KwargRule[float](min=0, max=1, default=0.95),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            # Unknown model names fall back to the smaller model's ceiling.
            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.

        Performs a cheap live "ping" completion; failures are surfaced as
        CredentialsValidateFailedError.
        """
        if 'minimax_group_id' not in credentials:
            raise CredentialsValidateFailedError('MiniMax Group ID must be provided.')

        if 'minimax_api_key' not in credentials:
            raise CredentialsValidateFailedError('MiniMax API Key must be provided.')

        try:
            credential_kwargs = {
                'minimax_group_id': credentials['minimax_group_id'],
                'minimax_api_key': credentials['minimax_api_key'],
            }

            llm = Minimax(
                model='abab5.5-chat',
                max_tokens=10,
                temperature=0.01,
                **credential_kwargs
            )

            llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        # Only the API key is secret; the group id is stored in the clear.
        credentials['minimax_api_key'] = encrypter.encrypt_token(tenant_id, credentials['minimax_api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Return the stored provider credentials, decrypted for use.

        :param obfuscated: mask the API key for safe display when True.
        :return: credentials dict; empty for unsupported provider/quota types.
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                or (self.provider.provider_type == ProviderType.SYSTEM.value
                    and self.provider.quota_type == ProviderQuotaType.FREE.value):
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'minimax_group_id': None,
                    'minimax_api_key': None,
                }

            # Tolerate stored configs missing either key; plain indexing
            # previously raised KeyError on such records, while the
            # JSONDecodeError fallback shows None is the intended default.
            credentials.setdefault('minimax_group_id', None)
            credentials.setdefault('minimax_api_key', None)

            if credentials['minimax_api_key']:
                credentials['minimax_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['minimax_api_key']
                )

                if obfuscated:
                    credentials['minimax_api_key'] = encrypter.obfuscated_token(credentials['minimax_api_key'])

            return credentials

        return {}

    def should_deduct_quota(self):
        # Hosted MiniMax usage always counts against the tenant quota.
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        # Models share provider-level credentials; nothing model-specific.
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return: empty dict — no per-model credentials for this provider.
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return: the provider-level credentials — all models share them.
        """
        return self.get_provider_credentials(obfuscated)
|
||||
289
api/core/model_providers/providers/openai_provider.py
Normal file
289
api/core/model_providers/providers/openai_provider.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Type, Optional
|
||||
|
||||
from flask import current_app
|
||||
from openai.error import AuthenticationError, OpenAIError
|
||||
|
||||
import openai
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.entity.provider import ModelFeature
|
||||
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.llm.openai_model import OpenAIModel
|
||||
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.model_providers.providers.hosted import hosted_model_providers
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
|
||||
class OpenAIProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'openai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            models = [
                {
                    'id': 'gpt-3.5-turbo',
                    'name': 'gpt-3.5-turbo',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-3.5-turbo-16k',
                    'name': 'gpt-3.5-turbo-16k',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-4',
                    'name': 'gpt-4',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-4-32k',
                    'name': 'gpt-4-32k',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'text-davinci-003',
                    'name': 'text-davinci-003',
                }
            ]

            # Trial (system quota) tenants do not get access to GPT-4 models.
            if self.provider.provider_type == ProviderType.SYSTEM.value \
                    and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
                models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]

            return models
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'text-embedding-ada-002',
                    'name': 'text-embedding-ada-002'
                }
            ]
        elif model_type == ModelType.SPEECH_TO_TEXT:
            return [
                {
                    'id': 'whisper-1',
                    'name': 'whisper-1'
                }
            ]
        elif model_type == ModelType.MODERATION:
            return [
                {
                    'id': 'text-moderation-stable',
                    'name': 'text-moderation-stable'
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for unsupported model types.
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = OpenAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = OpenAIEmbedding
        elif model_type == ModelType.MODERATION:
            model_class = OpenAIModeration
        elif model_type == ModelType.SPEECH_TO_TEXT:
            model_class = OpenAIWhisper
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        model_max_tokens = {
            'gpt-4': 8192,
            'gpt-4-32k': 32768,
            'gpt-3.5-turbo': 4096,
            'gpt-3.5-turbo-16k': 16384,
            'text-davinci-003': 4097,
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=1),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.

        Issues a minimal ChatCompletion "ping" against the configured
        endpoint; auth/API errors become CredentialsValidateFailedError,
        anything unexpected is logged and re-raised.
        """
        if 'openai_api_key' not in credentials:
            raise CredentialsValidateFailedError('OpenAI API key is required')

        try:
            credentials_kwargs = {
                "api_key": credentials['openai_api_key']
            }

            if 'openai_api_base' in credentials and credentials['openai_api_base']:
                credentials_kwargs['api_base'] = credentials['openai_api_base'] + '/v1'

            if 'openai_organization' in credentials:
                credentials_kwargs['organization'] = credentials['openai_organization']

            openai.ChatCompletion.create(
                messages=[{"role": "user", "content": 'ping'}],
                model='gpt-3.5-turbo',
                timeout=10,
                request_timeout=(5, 30),
                max_tokens=20,
                **credentials_kwargs
            )
        except (AuthenticationError, OpenAIError) as ex:
            raise CredentialsValidateFailedError(str(ex))
        except Exception as ex:
            logging.exception('OpenAI config validation failed')
            raise ex

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Return credentials for use: tenant-stored (decrypted) for custom
        providers, the platform's hosted keys otherwise.

        :param obfuscated: mask the API key for safe display when True.
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                # Legacy records stored the raw key instead of a JSON blob.
                credentials = {
                    'openai_api_base': None,
                    'openai_api_key': self.provider.encrypted_config,
                    'openai_organization': None
                }

            # Tolerate valid-JSON configs saved without the key; plain
            # indexing previously raised KeyError on such records.
            credentials.setdefault('openai_api_key', None)

            if credentials['openai_api_key']:
                credentials['openai_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['openai_api_key']
                )

                if obfuscated:
                    credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])

            if 'openai_api_base' not in credentials or not credentials['openai_api_base']:
                credentials['openai_api_base'] = None
            else:
                credentials['openai_api_base'] = credentials['openai_api_base'] + '/v1'

            if 'openai_organization' not in credentials:
                credentials['openai_organization'] = None

            return credentials
        else:
            if hosted_model_providers.openai:
                return {
                    'openai_api_base': hosted_model_providers.openai.api_base,
                    'openai_api_key': hosted_model_providers.openai.api_key,
                    'openai_organization': hosted_model_providers.openai.api_organization
                }
            else:
                return {
                    'openai_api_base': None,
                    'openai_api_key': None,
                    'openai_organization': None
                }

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        # Hosted (system) quota is only offered in the cloud edition and
        # only when a platform OpenAI key is configured.
        if current_app.config['EDITION'] != 'CLOUD':
            return False

        if hosted_model_providers.openai:
            return True

        return False

    def should_deduct_quota(self):
        if hosted_model_providers.openai \
                and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > 0:
            return True

        return False

    def get_payment_info(self) -> Optional[dict]:
        """
        get payment info if it payable.

        :return: Stripe product info, or None when paid quota is disabled.
        """
        if hosted_model_providers.openai \
                and hosted_model_providers.openai.paid_enabled:
            return {
                'product_id': hosted_model_providers.openai.paid_stripe_price_id,
                'increase_quota': hosted_model_providers.openai.paid_increase_quota,
            }

        return None

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        # Models share provider-level credentials; nothing model-specific.
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return: empty dict — no per-model credentials for this provider.
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return: the provider-level credentials — all models share them.
        """
        return self.get_provider_credentials(obfuscated)
|
||||
184
api/core/model_providers/providers/replicate_provider.py
Normal file
184
api/core/model_providers/providers/replicate_provider.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
import replicate
|
||||
from replicate.exceptions import ReplicateError
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.llm.replicate_model import ReplicateModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class ReplicateProvider(BaseModelProvider):
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'replicate'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = ReplicateModel
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
model_class = ReplicateEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
model_credentials = self.get_model_credentials(model_name, model_type)
|
||||
|
||||
model = replicate.Client(api_token=model_credentials.get("replicate_api_token")).models.get(model_name)
|
||||
|
||||
try:
|
||||
version = model.versions.get(model_credentials['model_version'])
|
||||
except ReplicateError as e:
|
||||
raise CredentialsValidateFailedError(f"Model {model_name}:{model_credentials['model_version']} not exists, "
|
||||
f"cause: {e.__class__.__name__}:{str(e)}")
|
||||
except Exception as e:
|
||||
logging.exception("Model validate failed.")
|
||||
raise e
|
||||
|
||||
model_kwargs_rules = ModelKwargsRules()
|
||||
for key, value in version.openapi_schema['components']['schemas']['Input']['properties'].items():
|
||||
if key not in ['debug', 'prompt'] and value['type'] in ['number', 'integer']:
|
||||
if key == ['temperature', 'top_p']:
|
||||
kwarg_rule = KwargRule[float](
|
||||
type=KwargRuleType.FLOAT.value if value['type'] == 'number' else KwargRuleType.INTEGER.value,
|
||||
min=float(value.get('minimum')) if value.get('minimum') is not None else None,
|
||||
max=float(value.get('maximum')) if value.get('maximum') is not None else None,
|
||||
default=float(value.get('default')) if value.get('default') is not None else None,
|
||||
)
|
||||
if key == 'temperature':
|
||||
model_kwargs_rules.temperature = kwarg_rule
|
||||
else:
|
||||
model_kwargs_rules.top_p = kwarg_rule
|
||||
elif key in ['max_length', 'max_new_tokens']:
|
||||
model_kwargs_rules.max_tokens = KwargRule[int](
|
||||
alias=key,
|
||||
type=KwargRuleType.INTEGER.value,
|
||||
min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
|
||||
max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
|
||||
default=int(value.get('default')) if value.get('default') is not None else 500,
|
||||
)
|
||||
|
||||
return model_kwargs_rules
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
if 'replicate_api_token' not in credentials:
|
||||
raise CredentialsValidateFailedError('Replicate API Key must be provided.')
|
||||
|
||||
if 'model_version' not in credentials:
|
||||
raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
|
||||
|
||||
if model_name.count("/") != 1:
|
||||
raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
|
||||
'format: {user_name}/{model_name}')
|
||||
|
||||
version = credentials['model_version']
|
||||
try:
|
||||
model = replicate.Client(api_token=credentials.get("replicate_api_token")).models.get(model_name)
|
||||
rst = model.versions.get(version)
|
||||
|
||||
if model_type == ModelType.EMBEDDINGS \
|
||||
and 'Embedding' not in rst.openapi_schema['components']['schemas']:
|
||||
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
|
||||
elif model_type == ModelType.TEXT_GENERATION \
|
||||
and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
|
||||
or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
|
||||
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
|
||||
except ReplicateError as e:
|
||||
raise CredentialsValidateFailedError(
|
||||
f"Model {model_name}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
|
||||
except Exception as e:
|
||||
logging.exception("Replicate config validation failed.")
|
||||
raise e
|
||||
|
||||
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                              credentials: dict) -> dict:
    """
    Encrypt model credentials for persistence.

    :param tenant_id: tenant the token belongs to
    :param model_name: model name (unused; kept for interface symmetry)
    :param model_type: model type (unused; kept for interface symmetry)
    :param credentials: dict holding a plain-text ``replicate_api_token``
    :return: the same dict with ``replicate_api_token`` replaced by its
        encrypted form
    """
    plain_token = credentials['replicate_api_token']
    credentials['replicate_api_token'] = encrypter.encrypt_token(tenant_id, plain_token)
    return credentials
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
    """
    Get credentials for llm use, decrypting the stored API token.

    :param model_name: model name
    :param model_type: model type
    :param obfuscated: when True, return the token in obfuscated form
        (suitable for display)
    :return: dict with a ``replicate_api_token`` entry (None when no
        config has been saved yet)
    :raises NotImplementedError: for non-custom providers
    """
    if self.provider.provider_type != ProviderType.CUSTOM.value:
        raise NotImplementedError

    provider_model = self._get_provider_model(model_name, model_type)

    # No stored config yet: return an explicit empty credential.
    if not provider_model.encrypted_config:
        return {'replicate_api_token': None}

    credentials = json.loads(provider_model.encrypted_config)
    token = credentials['replicate_api_token']
    if token:
        token = encrypter.decrypt_token(self.provider.tenant_id, token)
        if obfuscated:
            token = encrypter.obfuscated_token(token)
        credentials['replicate_api_token'] = token

    return credentials
|
||||
|
||||
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
    """
    Validate provider-level credentials.

    Replicate keeps all secrets per model, so there is nothing to
    validate at the provider level; this is intentionally a no-op.
    """
    return
|
||||
|
||||
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
    """
    Encrypt provider-level credentials for persistence.

    Replicate stores its credentials per model rather than per provider,
    so there is nothing to encrypt here.

    :param tenant_id: tenant the credentials belong to
    :param credentials: raw provider credentials (ignored)
    :return: an empty dict
    """
    return {}
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
    """
    Get provider-level credentials.

    Replicate has no provider-level credentials (they are stored per
    model), so an empty dict is always returned.

    :param obfuscated: unused; kept for interface symmetry
    :return: an empty dict
    """
    return {}
|
||||
# --- New file: api/core/model_providers/providers/spark_provider.py (191 lines) ---
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from flask import current_app
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.llm.spark_model import SparkModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.spark import ChatSpark
|
||||
from core.third_party.spark.spark_llm import SparkError
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
|
||||
class SparkProvider(BaseModelProvider):
    """Model provider for iFLYTEK Spark (星火认知大模型) text generation."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'spark'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the static list of Spark models available for `model_type`."""
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'spark',
                    'name': '星火认知大模型',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for unsupported model types
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = SparkModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Get model parameter rules.

        Spark only supports `temperature` and `max_tokens`; the other common
        knobs are disabled.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.5),
            top_p=KwargRule[float](enabled=False),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](min=10, max=4096, default=2048),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials by issuing a minimal chat request.

        :raises CredentialsValidateFailedError: when a key is missing or the
            Spark API rejects the credentials
        """
        if 'app_id' not in credentials:
            raise CredentialsValidateFailedError('Spark app_id must be provided.')

        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Spark api_key must be provided.')

        if 'api_secret' not in credentials:
            raise CredentialsValidateFailedError('Spark api_secret must be provided.')

        try:
            credential_kwargs = {
                'app_id': credentials['app_id'],
                'api_key': credentials['api_key'],
                'api_secret': credentials['api_secret'],
            }

            # Cheap validation call: tiny token budget, near-deterministic.
            chat_llm = ChatSpark(
                max_tokens=10,
                temperature=0.01,
                **credential_kwargs
            )

            messages = [
                HumanMessage(
                    content="ping"
                )
            ]

            chat_llm(messages)
        except SparkError as ex:
            raise CredentialsValidateFailedError(str(ex))
        except Exception as ex:
            logging.exception('Spark config validation failed')
            raise ex

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """Encrypt the secret parts (api_key / api_secret) for storage."""
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        credentials['api_secret'] = encrypter.encrypt_token(tenant_id, credentials['api_secret'])
        return credentials

    def _decrypt_credential(self, token, obfuscated: bool):
        """Decrypt one stored secret; optionally obfuscate it for display.

        Falsy tokens (None / empty string) are returned unchanged.
        """
        if not token:
            return token
        token = encrypter.decrypt_token(self.provider.tenant_id, token)
        if obfuscated:
            token = encrypter.obfuscated_token(token)
        return token

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Get provider credentials with secrets decrypted.

        Applies to custom providers and system providers on the free quota;
        otherwise only empty placeholders are returned.

        :param obfuscated: when True, secrets are obfuscated for display
        :return: dict with `app_id`, `api_key`, `api_secret`
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                or (self.provider.provider_type == ProviderType.SYSTEM.value
                    and self.provider.quota_type == ProviderQuotaType.FREE.value):
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except (JSONDecodeError, TypeError):
                # TypeError: encrypted_config may still be None before the
                # first save; treat it the same as unparsable config.
                credentials = {
                    'app_id': None,
                    'api_key': None,
                    'api_secret': None,
                }

            credentials['api_key'] = self._decrypt_credential(credentials['api_key'], obfuscated)
            credentials['api_secret'] = self._decrypt_credential(credentials['api_secret'], obfuscated)

            return credentials
        else:
            return {
                'app_id': None,
                'api_key': None,
                'api_secret': None,
            }

    def should_deduct_quota(self):
        """Spark usage always counts against the hosted quota."""
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Check model credentials valid.

        Spark keeps credentials at provider level only; intentionally a no-op.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        Encrypt model credentials for save.

        Nothing is stored per model for Spark.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return: an empty dict
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Get credentials for llm use.

        Model credentials are simply the provider credentials for Spark.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
|
||||
# --- New file: api/core/model_providers/providers/tongyi_provider.py (157 lines) ---
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.llm.tongyi_model import TongyiModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class TongyiProvider(BaseModelProvider):
    """Model provider for Alibaba Tongyi (Qwen) text generation models."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'tongyi'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the static list of Qwen models available for `model_type`."""
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'qwen-v1',
                    'name': 'qwen-v1',
                },
                {
                    'id': 'qwen-plus-v1',
                    'name': 'qwen-plus-v1',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for unsupported model types
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = TongyiModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Get model parameter rules.

        Tongyi only supports `top_p` and `max_tokens`; the other common knobs
        are disabled.

        :param model_name:
        :param model_type:
        :return:
        """
        # Per-model completion limits.
        model_max_tokens = {
            'qwen-v1': 1500,
            'qwen-plus-v1': 6500
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](enabled=False),
            top_p=KwargRule[float](min=0, max=1, default=0.8),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            # NOTE(review): max is None for an unknown model_name — presumably
            # only models from the fixed list reach here; confirm upstream.
            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials by issuing a minimal completion call.

        :raises CredentialsValidateFailedError: when the key is missing or the
            Dashscope API rejects it
        """
        if 'dashscope_api_key' not in credentials:
            raise CredentialsValidateFailedError('Dashscope API Key must be provided.')

        try:
            credential_kwargs = {
                'dashscope_api_key': credentials['dashscope_api_key']
            }

            llm = EnhanceTongyi(
                model_name='qwen-v1',
                max_retries=1,
                **credential_kwargs
            )

            llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """Encrypt the Dashscope API key for storage."""
        credentials['dashscope_api_key'] = encrypter.encrypt_token(tenant_id, credentials['dashscope_api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Get provider credentials with the API key decrypted.

        :param obfuscated: when True, the key is obfuscated for display
        :return: dict with `dashscope_api_key`; empty dict for non-custom providers
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except (JSONDecodeError, TypeError):
                # TypeError: encrypted_config may still be None before the
                # first save; treat it the same as unparsable config.
                credentials = {
                    'dashscope_api_key': None
                }

            if credentials['dashscope_api_key']:
                credentials['dashscope_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['dashscope_api_key']
                )

                if obfuscated:
                    credentials['dashscope_api_key'] = encrypter.obfuscated_token(credentials['dashscope_api_key'])

            return credentials

        return {}

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Check model credentials valid.

        Tongyi keeps credentials at provider level only; intentionally a no-op.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        Encrypt model credentials for save.

        Nothing is stored per model for Tongyi.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return: an empty dict
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Get credentials for llm use.

        Model credentials are simply the provider credentials for Tongyi.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
|
||||
# --- New file: api/core/model_providers/providers/wenxin_provider.py (182 lines) ---
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.llm.wenxin_model import WenxinModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.wenxin import Wenxin
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class WenxinProvider(BaseModelProvider):
    """Model provider for Baidu Wenxin (ERNIE-Bot / BLOOMZ) text generation."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'wenxin'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the static list of Wenxin models available for `model_type`."""
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'ernie-bot',
                    'name': 'ERNIE-Bot',
                },
                {
                    'id': 'ernie-bot-turbo',
                    'name': 'ERNIE-Bot-turbo',
                },
                {
                    'id': 'bloomz-7b',
                    'name': 'BLOOMZ-7B',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for unsupported model types
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = WenxinModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Get model parameter rules.

        Only the ERNIE-Bot family exposes sampling knobs; other models
        (e.g. BLOOMZ-7B) take no tunable parameters.

        :param model_name:
        :param model_type:
        :return:
        """
        if model_name in ['ernie-bot', 'ernie-bot-turbo']:
            return ModelKwargsRules(
                temperature=KwargRule[float](min=0.01, max=1, default=0.95),
                top_p=KwargRule[float](min=0.01, max=1, default=0.8),
                presence_penalty=KwargRule[float](enabled=False),
                frequency_penalty=KwargRule[float](enabled=False),
                max_tokens=KwargRule[int](enabled=False),
            )
        else:
            return ModelKwargsRules(
                temperature=KwargRule[float](enabled=False),
                top_p=KwargRule[float](enabled=False),
                presence_penalty=KwargRule[float](enabled=False),
                frequency_penalty=KwargRule[float](enabled=False),
                max_tokens=KwargRule[int](enabled=False),
            )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials by issuing a minimal completion call.

        :raises CredentialsValidateFailedError: when a key is missing or the
            Wenxin API rejects the credentials
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Wenxin api_key must be provided.')

        if 'secret_key' not in credentials:
            raise CredentialsValidateFailedError('Wenxin secret_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key'],
                'secret_key': credentials['secret_key'],
            }

            # Near-deterministic, minimal validation call.
            llm = Wenxin(
                temperature=0.01,
                **credential_kwargs
            )

            llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """Encrypt the secret parts (api_key / secret_key) for storage."""
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
        return credentials

    def _decrypt_credential(self, token, obfuscated: bool):
        """Decrypt one stored secret; optionally obfuscate it for display.

        Falsy tokens (None / empty string) are returned unchanged.
        """
        if not token:
            return token
        token = encrypter.decrypt_token(self.provider.tenant_id, token)
        if obfuscated:
            token = encrypter.obfuscated_token(token)
        return token

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Get provider credentials with secrets decrypted.

        :param obfuscated: when True, secrets are obfuscated for display
        :return: dict with `api_key` and `secret_key`; empty placeholders
            for non-custom providers
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except (JSONDecodeError, TypeError):
                # TypeError: encrypted_config may still be None before the
                # first save; treat it the same as unparsable config.
                credentials = {
                    'api_key': None,
                    'secret_key': None,
                }

            credentials['api_key'] = self._decrypt_credential(credentials['api_key'], obfuscated)
            credentials['secret_key'] = self._decrypt_credential(credentials['secret_key'], obfuscated)

            return credentials
        else:
            return {
                'api_key': None,
                'secret_key': None,
            }

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Check model credentials valid.

        Wenxin keeps credentials at provider level only; intentionally a no-op.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        Encrypt model credentials for save.

        Nothing is stored per model for Wenxin.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return: an empty dict
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Get credentials for llm use.

        Model credentials are simply the provider credentials for Wenxin.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
|
||||
# --- End of commit diff ---