Mirror of https://github.com/langgenius/dify.git
Initial commit
55  api/core/llm/error.py  Normal file
@@ -0,0 +1,55 @@
from typing import Optional


class LLMError(Exception):
    """Base class for all LLM exceptions."""
    description: Optional[str] = None

    def __init__(self, description: Optional[str] = None) -> None:
        self.description = description


class LLMBadRequestError(LLMError):
    """Raised when the LLM returns a bad request error."""
    description = "Bad Request"


class LLMAPIConnectionError(LLMError):
    """Raised when the connection to the LLM API fails."""
    description = "API Connection Error"


class LLMAPIUnavailableError(LLMError):
    """Raised when the LLM API is unavailable."""
    description = "API Unavailable Error"


class LLMRateLimitError(LLMError):
    """Raised when the LLM returns a rate limit error."""
    description = "Rate Limit Error"


class LLMAuthorizationError(LLMError):
    """Raised when the LLM returns an authorization error."""
    description = "Authorization Error"


class ProviderTokenNotInitError(Exception):
    """
    Custom exception raised when the provider token is not initialized.
    """
    description = "Provider Token Not Init"


class QuotaExceededError(Exception):
    """
    Custom exception raised when the quota for a provider has been exceeded.
    """
    description = "Quota Exceeded"


class ModelCurrentlyNotSupportError(Exception):
    """
    Custom exception raised when the model is not currently supported.
    """
    description = "Model Currently Not Support"
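These exceptions form a small hierarchy: the five LLM* errors share LLMError, while the provider-level errors derive from Exception directly, so callers that only care about model failures can catch LLMError alone. A minimal sketch of how a caller might map them to user-facing messages; `call_model` is a hypothetical helper, not part of this commit:

from core.llm.error import LLMError, ProviderTokenNotInitError, QuotaExceededError

def safe_completion(prompt: str) -> str:
    try:
        return call_model(prompt)  # hypothetical function that invokes the LLM
    except QuotaExceededError:
        return "The hosted quota for this provider is used up."
    except ProviderTokenNotInitError:
        return "No API key is configured for this provider."
    except LLMError as e:
        # all five LLM* errors land here; description carries a short label
        return "LLM call failed: {}".format(e.description)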
51  api/core/llm/error_handle_wraps.py  Normal file
@@ -0,0 +1,51 @@
import logging
from functools import wraps

import openai

from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
    LLMBadRequestError


def handle_llm_exceptions(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except openai.error.InvalidRequestError as e:
            logging.exception("Invalid request to OpenAI API.")
            raise LLMBadRequestError(str(e))
        except openai.error.APIConnectionError as e:
            logging.exception("Failed to connect to OpenAI API.")
            raise LLMAPIConnectionError(str(e))
        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
            logging.exception("OpenAI service unavailable.")
            raise LLMAPIUnavailableError(str(e))
        except openai.error.RateLimitError as e:
            raise LLMRateLimitError(str(e))
        except openai.error.AuthenticationError as e:
            raise LLMAuthorizationError(str(e))

    return wrapper


def handle_llm_exceptions_async(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except openai.error.InvalidRequestError as e:
            logging.exception("Invalid request to OpenAI API.")
            raise LLMBadRequestError(str(e))
        except openai.error.APIConnectionError as e:
            logging.exception("Failed to connect to OpenAI API.")
            raise LLMAPIConnectionError(str(e))
        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
            logging.exception("OpenAI service unavailable.")
            raise LLMAPIUnavailableError(str(e))
        except openai.error.RateLimitError as e:
            raise LLMRateLimitError(str(e))
        except openai.error.AuthenticationError as e:
            raise LLMAuthorizationError(str(e))

    return wrapper
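Both decorators translate the pre-1.0 `openai` SDK's exception types (`openai.error.*`) into the provider-agnostic errors above, logging the noisier failure modes before re-raising. A hedged usage sketch; `ask` is a made-up function, not from the commit:

import openai

from core.llm.error_handle_wraps import handle_llm_exceptions

@handle_llm_exceptions
def ask(prompt: str) -> str:
    # any openai.error.* raised in here surfaces as an LLMError subclass
    resp = openai.Completion.create(model="text-davinci-003", prompt=prompt)
    return resp["choices"][0]["text"]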
103  api/core/llm/llm_builder.py  Normal file
@@ -0,0 +1,103 @@
from typing import Union, Optional

from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM

from core.constant import llm_constant
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI


class LLMBuilder:
    """
    This class handles the following logic:
    1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
    2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
       OPENAI_API_TYPE=azure
       OPENAI_API_VERSION=2022-12-01
       OPENAI_API_BASE=https://your-resource-name.openai.azure.com
       OPENAI_API_KEY=<your Azure OpenAI API key>
    3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
    4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
    5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
    6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot.
    7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
    8. For providers with the provider_type 'System', quota_used must not exceed quota_limit; once the quota is exceeded, the provider can no longer be used. Currently, only the TRIAL quota_type is supported, and it never resets.
    """

    @classmethod
    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
        if model_name == 'fake':
            return FakeListLLM(responses=[])

        mode = cls.get_mode_by_model(model_name)
        if mode == 'chat':
            # llm_cls = StreamableAzureChatOpenAI
            llm_cls = StreamableChatOpenAI
        elif mode == 'completion':
            llm_cls = StreamableOpenAI
        else:
            raise ValueError(f"model name {model_name} is not supported.")

        model_credentials = cls.get_model_credentials(tenant_id, model_name)

        return llm_cls(
            model_name=model_name,
            temperature=kwargs.get('temperature', 0),
            max_tokens=kwargs.get('max_tokens', 256),
            top_p=kwargs.get('top_p', 1),
            frequency_penalty=kwargs.get('frequency_penalty', 0),
            presence_penalty=kwargs.get('presence_penalty', 0),
            callback_manager=kwargs.get('callback_manager', None),
            streaming=kwargs.get('streaming', False),
            # request_timeout=None
            **model_credentials
        )

    @classmethod
    def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
                          callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
        model_name = model.get("name")
        completion_params = model.get("completion_params", {})

        return cls.to_llm(
            tenant_id=tenant_id,
            model_name=model_name,
            temperature=completion_params.get('temperature', 0),
            max_tokens=completion_params.get('max_tokens', 256),
            top_p=completion_params.get('top_p', 0),
            frequency_penalty=completion_params.get('frequency_penalty', 0.1),
            presence_penalty=completion_params.get('presence_penalty', 0.1),
            streaming=streaming,
            callback_manager=callback_manager
        )

    @classmethod
    def get_mode_by_model(cls, model_name: str) -> str:
        if not model_name:
            raise ValueError("empty model name is not supported.")

        if model_name in llm_constant.models_by_mode['chat']:
            return "chat"
        elif model_name in llm_constant.models_by_mode['completion']:
            return "completion"
        else:
            raise ValueError(f"model name {model_name} is not supported.")

    @classmethod
    def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
        """
        Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
        Raises an exception if the model_name is not found or if the provider is not found.
        """
        if not model_name:
            raise Exception('model name not found')

        if model_name not in llm_constant.models:
            raise Exception('model {} not found'.format(model_name))

        model_provider = llm_constant.models[model_name]

        provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
        return provider_service.get_credentials(model_name)
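The builder is the single entry point the rest of the API uses to obtain a configured langchain LLM: it resolves the model's mode (chat vs. completion), pulls the tenant's decrypted credentials, and applies the completion parameters. A hedged sketch of a call site; the tenant id and model dict are illustrative values, not from the commit:

from core.llm.llm_builder import LLMBuilder

llm = LLMBuilder.to_llm_from_model(
    tenant_id="tenant-123",                      # illustrative id
    model={
        "name": "gpt-3.5-turbo",                 # must exist in llm_constant.models
        "completion_params": {"temperature": 0.7, "max_tokens": 512},
    },
    streaming=False,
)
# llm is a StreamableChatOpenAI here, since gpt-3.5-turbo is a chat-mode model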
15  api/core/llm/moderation.py  Normal file
@@ -0,0 +1,15 @@
import openai
from models.provider import ProviderName


class Moderation:

    def __init__(self, provider: str, api_key: str):
        self.provider = provider
        self.api_key = api_key

        # only OpenAI moderation is wired up; other providers leave self.client unset
        if self.provider == ProviderName.OPENAI.value:
            self.client = openai.Moderation

    def moderate(self, text):
        return self.client.create(input=text, api_key=self.api_key)
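Moderation wraps the pre-1.0 `openai.Moderation` endpoint, and it is also what OpenAIProvider.config_validate below uses to smoke-test a key. A brief usage sketch; the key is a placeholder, and the response shape assumed is that of the pre-1.0 OpenAI moderation API:

from core.llm.moderation import Moderation
from models.provider import ProviderName

result = Moderation(ProviderName.OPENAI.value, "sk-placeholder").moderate("some user text")
flagged = result["results"][0]["flagged"]  # True if the text trips a moderation category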
23  api/core/llm/provider/anthropic_provider.py  Normal file
@@ -0,0 +1,23 @@
from typing import Optional

from core.llm.provider.base import BaseProvider
from models.provider import ProviderName


class AnthropicProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        # todo
        return []

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Anthropic as a dictionary, for the given tenant_id.
        The dictionary contains a single key: anthropic_api_key.
        """
        return {
            'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.ANTHROPIC
105  api/core/llm/provider/azure_provider.py  Normal file
@@ -0,0 +1,105 @@
import json
from typing import Optional, Union

import requests

from core.llm.provider.base import BaseProvider
from models.provider import ProviderName


class AzureProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        url = "{}/openai/deployments?api-version={}".format(
            credentials.get('openai_api_base'),
            credentials.get('openai_api_version')
        )

        headers = {
            "api-key": credentials.get('openai_api_key'),
            "content-type": "application/json; charset=utf-8"
        }

        response = requests.get(url, headers=headers)

        if response.status_code == 200:
            result = response.json()
            return [{
                'id': deployment['id'],
                'name': '{} ({})'.format(deployment['id'], deployment['model'])
            } for deployment in result['data'] if deployment['status'] == 'succeeded']
        else:
            # TODO: optimize in future
            raise Exception('Failed to get deployments from Azure OpenAI. Status code: {}'.format(response.status_code))

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Azure OpenAI as a dictionary.
        """
        encrypted_config = self.get_provider_api_key(model_id=model_id)
        config = json.loads(encrypted_config)
        config['openai_api_type'] = 'azure'
        config['deployment_name'] = model_id
        return config

    def get_provider_name(self):
        return ProviderName.AZURE_OPENAI

    def get_provider_configs(self, obfuscated: bool = False) -> Union[str, dict]:
        """
        Returns the provider configs.
        """
        try:
            config = self.get_provider_api_key()
            config = json.loads(config)
        except Exception:
            # fall back to a placeholder config when no key is stored yet
            config = {
                'openai_api_type': 'azure',
                'openai_api_version': '2023-03-15-preview',
                'openai_api_base': 'https://foo.microsoft.com/bar',
                'openai_api_key': ''
            }

        if obfuscated:
            if not config.get('openai_api_key'):
                config = {
                    'openai_api_type': 'azure',
                    'openai_api_version': '2023-03-15-preview',
                    'openai_api_base': 'https://foo.microsoft.com/bar',
                    'openai_api_key': ''
                }

            config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
            return config

        return config

    def get_token_type(self):
        # TODO: change to dict when implemented
        return lambda value: value

    def config_validate(self, config: Union[dict, str]):
        """
        Validates the given config.
        """
        # TODO: implement
        pass

    def get_encrypted_token(self, config: Union[dict, str]):
        """
        Returns the encrypted token.
        """
        return json.dumps({
            'openai_api_type': 'azure',
            'openai_api_version': '2023-03-15-preview',
            'openai_api_base': config['openai_api_base'],
            'openai_api_key': self.encrypt_token(config['openai_api_key'])
        })

    def get_decrypted_token(self, token: str):
        """
        Returns the decrypted token.
        """
        config = json.loads(token)
        config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
        return config
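Unlike the other providers, Azure stores a whole JSON object in encrypted_config, with only the API key field actually encrypted. A sketch of the stored value that get_encrypted_token produces and get_decrypted_token reverses; the field values are illustrative:

# what ends up in Provider.encrypted_config for Azure (values illustrative):
stored = {
    "openai_api_type": "azure",
    "openai_api_version": "2023-03-15-preview",
    "openai_api_base": "https://my-resource.openai.azure.com",
    "openai_api_key": "<base64 RSA ciphertext>",   # only this field is encrypted
}
# get_credentials() json-decodes it, decrypts the key, and adds
# openai_api_type='azure' plus deployment_name=<model_id> before the dict is
# handed to the langchain client as keyword arguments.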
124  api/core/llm/provider/base.py  Normal file
@@ -0,0 +1,124 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional, Union

from core import hosted_llm_credentials
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.provider import Provider, ProviderType, ProviderName


class BaseProvider(ABC):
    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
        """
        Returns the decrypted API key for the given tenant_id and provider_name.
        If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
        If the provider is not found or not valid, raises a ProviderTokenNotInitError.
        """
        provider = self.get_provider(prefer_custom)
        if not provider:
            raise ProviderTokenNotInitError()

        if provider.provider_type == ProviderType.SYSTEM.value:
            quota_used = provider.quota_used if provider.quota_used is not None else 0
            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0

            if model_id and model_id == 'gpt-4':
                raise ModelCurrentlyNotSupportError()

            if quota_used >= quota_limit:
                raise QuotaExceededError()

            return self.get_hosted_credentials()
        else:
            return self.get_decrypted_token(provider.encrypted_config)

    def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
        """
        Returns the Provider instance for the given tenant_id and provider_name.
        If both CUSTOM and SYSTEM providers exist, the preferred one is returned based on the prefer_custom flag.
        """
        providers = db.session.query(Provider).filter(
            Provider.tenant_id == self.tenant_id,
            Provider.provider_name == self.get_provider_name().value
        ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()

        custom_provider = None
        system_provider = None

        for provider in providers:
            if provider.provider_type == ProviderType.CUSTOM.value:
                custom_provider = provider
            elif provider.provider_type == ProviderType.SYSTEM.value:
                system_provider = provider

        if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
            return custom_provider
        elif system_provider and system_provider.is_valid:
            return system_provider
        else:
            return None

    def get_hosted_credentials(self) -> str:
        if self.get_provider_name() != ProviderName.OPENAI:
            raise ProviderTokenNotInitError()

        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
            raise ProviderTokenNotInitError()

        return hosted_llm_credentials.openai.api_key

    def get_provider_configs(self, obfuscated: bool = False) -> Union[str, dict]:
        """
        Returns the provider configs.
        """
        try:
            config = self.get_provider_api_key()
        except Exception:
            config = 'THIS-IS-A-MOCK-TOKEN'

        if obfuscated:
            return self.obfuscated_token(config)

        return config

    def obfuscated_token(self, token: str):
        # keep the first 6 and last 2 characters, mask everything in between
        return token[:6] + '*' * (len(token) - 8) + token[-2:]

    def get_token_type(self):
        return str

    def get_encrypted_token(self, config: Union[dict, str]):
        return self.encrypt_token(config)

    def get_decrypted_token(self, token: str):
        return self.decrypt_token(token)

    def encrypt_token(self, token):
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token):
        return rsa.decrypt(base64.b64decode(token), self.tenant_id)

    @abstractmethod
    def get_provider_name(self):
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        raise NotImplementedError

    @abstractmethod
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        raise NotImplementedError

    @abstractmethod
    def config_validate(self, config: str):
        raise NotImplementedError
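The obfuscation keeps just enough of the key for a user to recognize it in the console. A worked example of the masking rule, restated as a standalone function with a fake key:

def obfuscated_token(token: str) -> str:
    # same masking rule as BaseProvider.obfuscated_token
    return token[:6] + '*' * (len(token) - 8) + token[-2:]

# 29-character fake key: first 6 chars + 21 asterisks + last 2 chars
assert obfuscated_token("sk-abcdefghijklmnopqrstuvwxyz") == "sk-abc" + "*" * 21 + "yz"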
2  api/core/llm/provider/errors.py  Normal file
@@ -0,0 +1,2 @@
class ValidateFailedError(Exception):
    description = "Provider Validate failed"
22  api/core/llm/provider/huggingface_provider.py  Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from core.llm.provider.base import BaseProvider
from models.provider import ProviderName


class HuggingfaceProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        # todo
        return []

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
        """
        return {
            'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.HUGGINGFACEHUB
53  api/core/llm/provider/llm_provider_service.py  Normal file
@@ -0,0 +1,53 @@
from typing import Optional, Union

from core.llm.provider.anthropic_provider import AnthropicProvider
from core.llm.provider.azure_provider import AzureProvider
from core.llm.provider.base import BaseProvider
from core.llm.provider.huggingface_provider import HuggingfaceProvider
from core.llm.provider.openai_provider import OpenAIProvider
from models.provider import Provider


class LLMProviderService:

    def __init__(self, tenant_id: str, provider_name: str):
        self.provider = self.init_provider(tenant_id, provider_name)

    def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
        if provider_name == 'openai':
            return OpenAIProvider(tenant_id)
        elif provider_name == 'azure_openai':
            return AzureProvider(tenant_id)
        elif provider_name == 'anthropic':
            return AnthropicProvider(tenant_id)
        elif provider_name == 'huggingface':
            return HuggingfaceProvider(tenant_id)
        else:
            raise Exception('provider {} not found'.format(provider_name))

    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        return self.provider.get_models(model_id)

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        return self.provider.get_credentials(model_id)

    def get_provider_configs(self, obfuscated: bool = False) -> Union[str, dict]:
        return self.provider.get_provider_configs(obfuscated)

    def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
        return self.provider.get_provider(prefer_custom)

    def config_validate(self, config: Union[dict, str]):
        """
        Validates the given config.

        :param config:
        :raises: ValidateFailedError
        """
        return self.provider.config_validate(config)

    def get_token_type(self):
        return self.provider.get_token_type()

    def get_encrypted_token(self, config: Union[dict, str]):
        return self.provider.get_encrypted_token(config)
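LLMProviderService is a thin facade: it picks the right BaseProvider subclass from the provider_name string and forwards every call to it. A hedged sketch of how a key might be validated and stored; the tenant id and key are placeholders:

from core.llm.provider.errors import ValidateFailedError
from core.llm.provider.llm_provider_service import LLMProviderService

service = LLMProviderService(tenant_id="tenant-123", provider_name="openai")
try:
    service.config_validate("sk-placeholder")             # raises on a bad key
    encrypted = service.get_encrypted_token("sk-placeholder")
    # `encrypted` would then be persisted on the Provider row as encrypted_config
except ValidateFailedError as e:
    print("key rejected:", e)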
44  api/core/llm/provider/openai_provider.py  Normal file
@@ -0,0 +1,44 @@
import logging
from typing import Optional, Union

import openai
from openai.error import AuthenticationError, OpenAIError

from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName


class OpenAIProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        response = openai.Model.list(**credentials)

        return [{
            'id': model['id'],
            'name': model['id'],
        } for model in response['data']]

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the credentials for the given tenant_id and provider_name.
        """
        return {
            'openai_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.OPENAI

    def config_validate(self, config: Union[dict, str]):
        """
        Validates the given config by issuing a cheap moderation request with the key.
        """
        try:
            Moderation(self.get_provider_name().value, config).moderate('test')
        except (AuthenticationError, OpenAIError) as ex:
            raise ValidateFailedError(str(ex))
        except Exception as ex:
            logging.exception('OpenAI config validation failed')
            raise ex
89  api/core/llm/streamable_azure_chat_open_ai.py  Normal file
@@ -0,0 +1,89 @@
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableAzureChatOpenAI(AzureChatOpenAI):
    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # tokenize all message contents in a single pass
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens

    def _generate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        self.callback_manager.on_llm_start(
            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
            verbose=self.verbose
        )

        chat_result = super()._generate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )
        self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    async def _agenerate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
                verbose=self.verbose
            )
        else:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
                verbose=self.verbose
            )

        chat_result = await super()._agenerate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )

        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
        else:
            self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    @handle_llm_exceptions
    def generate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(messages, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(messages, stop)
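get_messages_tokens approximates the chat-format overhead with flat constants: 3 tokens per request plus 5 per message, then a single tokenizer pass over all message contents concatenated. A minimal worked sketch, assuming the combined contents encode to 40 tokens:

tokens_per_message, tokens_per_request = 5, 3
num_messages, content_tokens = 2, 40           # assumed values
estimate = tokens_per_request + num_messages * tokens_per_message + content_tokens
# estimate == 53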
86  api/core/llm/streamable_chat_open_ai.py  Normal file
@@ -0,0 +1,86 @@
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableChatOpenAI(ChatOpenAI):

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # tokenize all message contents in a single pass
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens

    def _generate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        self.callback_manager.on_llm_start(
            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
        )

        chat_result = super()._generate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )
        self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    async def _agenerate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
            )
        else:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
            )

        chat_result = await super()._agenerate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )

        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
        else:
            self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    @handle_llm_exceptions
    def generate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(messages, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(messages, stop)
||||
20
api/core/llm/streamable_open_ai.py
Normal file
20
api/core/llm/streamable_open_ai.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from langchain.schema import LLMResult
|
||||
from typing import Optional, List
|
||||
from langchain import OpenAI
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
|
||||
|
||||
class StreamableOpenAI(OpenAI):
|
||||
|
||||
@handle_llm_exceptions
|
||||
def generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
return super().generate(prompts, stop)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(prompts, stop)
|
||||
41  api/core/llm/token_calculator.py  Normal file
@@ -0,0 +1,41 @@
import decimal
from typing import Optional

import tiktoken

from core.constant import llm_constant


class TokenCalculator:
    @classmethod
    def get_num_tokens(cls, model_name: str, text: str):
        if len(text) == 0:
            return 0

        enc = tiktoken.encoding_for_model(model_name)

        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    @classmethod
    def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
        if model_name in llm_constant.models_by_mode['embedding']:
            unit_price = llm_constant.model_prices[model_name]['usage']
        elif text_type == 'prompt':
            unit_price = llm_constant.model_prices[model_name]['prompt']
        elif text_type == 'completion':
            unit_price = llm_constant.model_prices[model_name]['completion']
        else:
            raise Exception('Invalid text type')

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    @classmethod
    def get_currency(cls, model_name: str):
        return llm_constant.model_currency
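Prices in llm_constant.model_prices are per 1,000 tokens, so the calculator scales the token count down first, rounding to three decimal places, and rounds the final price to seven. A worked example with an assumed unit price of 0.002 per 1K tokens:

import decimal

tokens = 1234
unit_price = decimal.Decimal('0.002')   # assumed price per 1K tokens

tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(
    decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)      # 1.234
total = (tokens_per_1k * unit_price).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)  # 0.0024680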