mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 04:59:23 +08:00
feat: server multi models support (#799)
This commit is contained in:
@@ -2,42 +2,11 @@ import re
|
||||
import uuid
|
||||
|
||||
from core.agent.agent_executor import PlanningStrategy
|
||||
from core.constant import llm_constant
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
|
||||
MODEL_PROVIDERS = [
|
||||
'openai',
|
||||
'anthropic',
|
||||
]
|
||||
|
||||
MODELS_BY_APP_MODE = {
|
||||
'chat': [
|
||||
'claude-instant-1',
|
||||
'claude-2',
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-16k',
|
||||
],
|
||||
'completion': [
|
||||
'claude-instant-1',
|
||||
'claude-2',
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-16k',
|
||||
'text-davinci-003',
|
||||
]
|
||||
}
|
||||
|
||||
SUPPORT_AGENT_MODELS = [
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
]
|
||||
|
||||
SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
|
||||
|
||||
@@ -65,40 +34,40 @@ class AppModelConfigService:
|
||||
# max_tokens
|
||||
if 'max_tokens' not in cp:
|
||||
cp["max_tokens"] = 512
|
||||
|
||||
if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
|
||||
llm_constant.max_context_token_length[model_name]:
|
||||
raise ValueError(
|
||||
"max_tokens must be an integer greater than 0 "
|
||||
"and not exceeding the maximum value of the corresponding model")
|
||||
|
||||
#
|
||||
# if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
|
||||
# llm_constant.max_context_token_length[model_name]:
|
||||
# raise ValueError(
|
||||
# "max_tokens must be an integer greater than 0 "
|
||||
# "and not exceeding the maximum value of the corresponding model")
|
||||
#
|
||||
# temperature
|
||||
if 'temperature' not in cp:
|
||||
cp["temperature"] = 1
|
||||
|
||||
if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
|
||||
raise ValueError("temperature must be a float between 0 and 2")
|
||||
|
||||
#
|
||||
# if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
|
||||
# raise ValueError("temperature must be a float between 0 and 2")
|
||||
#
|
||||
# top_p
|
||||
if 'top_p' not in cp:
|
||||
cp["top_p"] = 1
|
||||
|
||||
if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
|
||||
raise ValueError("top_p must be a float between 0 and 2")
|
||||
|
||||
# if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
|
||||
# raise ValueError("top_p must be a float between 0 and 2")
|
||||
#
|
||||
# presence_penalty
|
||||
if 'presence_penalty' not in cp:
|
||||
cp["presence_penalty"] = 0
|
||||
|
||||
if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
|
||||
raise ValueError("presence_penalty must be a float between -2 and 2")
|
||||
|
||||
# if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
|
||||
# raise ValueError("presence_penalty must be a float between -2 and 2")
|
||||
#
|
||||
# presence_penalty
|
||||
if 'frequency_penalty' not in cp:
|
||||
cp["frequency_penalty"] = 0
|
||||
|
||||
if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
|
||||
raise ValueError("frequency_penalty must be a float between -2 and 2")
|
||||
# if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
|
||||
# raise ValueError("frequency_penalty must be a float between -2 and 2")
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_cp = {
|
||||
@@ -112,7 +81,7 @@ class AppModelConfigService:
|
||||
return filtered_cp
|
||||
|
||||
@staticmethod
|
||||
def validate_configuration(account: Account, config: dict, mode: str) -> dict:
|
||||
def validate_configuration(tenant_id: str, account: Account, config: dict) -> dict:
|
||||
# opening_statement
|
||||
if 'opening_statement' not in config or not config["opening_statement"]:
|
||||
config["opening_statement"] = ""
|
||||
@@ -211,14 +180,21 @@ class AppModelConfigService:
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# model.provider
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
|
||||
raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
|
||||
model_provider_names = ModelProviderFactory.get_provider_names()
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
# model.name
|
||||
if 'name' not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, config["model"]["provider"])
|
||||
if not model_provider:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
model_list = model_provider.get_supported_model_list(ModelType.TEXT_GENERATION)
|
||||
model_ids = [m['id'] for m in model_list]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
# model.completion_params
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import io
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
from core.llm.whisper import Whisper
|
||||
from models.provider import ProviderName
|
||||
|
||||
FILE_SIZE = 15
|
||||
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
|
||||
ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm']
|
||||
|
||||
|
||||
class AudioService:
|
||||
@classmethod
|
||||
def transcript(cls, tenant_id: str, file: FileStorage):
|
||||
@@ -26,14 +24,12 @@ class AudioService:
|
||||
if file_size > FILE_SIZE_LIMIT:
|
||||
message = f"Audio size larger than {FILE_SIZE} mb"
|
||||
raise AudioTooLargeServiceError(message)
|
||||
|
||||
provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
|
||||
if provider_name != ProviderName.OPENAI.value:
|
||||
raise ProviderNotSupportSpeechToTextServiceError()
|
||||
|
||||
provider_service = LLMProviderService(tenant_id, provider_name)
|
||||
model = ModelFactory.get_speech2text_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
buffer = io.BytesIO(file_content)
|
||||
buffer.name = 'temp.mp3'
|
||||
|
||||
return Whisper(provider_service.provider).transcribe(buffer)
|
||||
return model.run(buffer)
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy import and_
|
||||
|
||||
from core.completion import Completion
|
||||
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
|
||||
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@@ -127,9 +127,9 @@ class CompletionService:
|
||||
|
||||
# validate config
|
||||
model_config = AppModelConfigService.validate_configuration(
|
||||
tenant_id=app_model.tenant_id,
|
||||
account=user,
|
||||
config=args['model_config'],
|
||||
mode=app_model.mode
|
||||
config=args['model_config']
|
||||
)
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
|
||||
@@ -9,8 +9,7 @@ from typing import Optional, List
|
||||
from flask import current_app
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.llm.token_calculator import TokenCalculator
|
||||
from events.event_handlers.document_index_event import document_index_created
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -875,8 +874,13 @@ class SegmentService:
|
||||
content = args['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=document.tenant_id
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == document.id
|
||||
).scalar()
|
||||
@@ -921,8 +925,13 @@ class SegmentService:
|
||||
update_segment_keyword_index_task.delay(segment.id)
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=document.tenant_id
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
segment.content = content
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
|
||||
@@ -4,14 +4,13 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DocumentSegment, DatasetQuery
|
||||
@@ -29,15 +28,11 @@ class HitTestingService:
|
||||
"records": []
|
||||
}
|
||||
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(OpenAIEmbeddings(
|
||||
**model_credentials
|
||||
))
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
|
||||
158
api/services/provider_checkout_service.py
Normal file
158
api/services/provider_checkout_service.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import datetime
|
||||
import logging
|
||||
|
||||
import stripe
|
||||
from flask import current_app
|
||||
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType
|
||||
|
||||
|
||||
class ProviderCheckout:
|
||||
def __init__(self, stripe_checkout_session):
|
||||
self.stripe_checkout_session = stripe_checkout_session
|
||||
|
||||
def get_checkout_url(self):
|
||||
return self.stripe_checkout_session.url
|
||||
|
||||
|
||||
class ProviderCheckoutService:
|
||||
def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout:
|
||||
# check provider name is valid
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rules()
|
||||
if provider_name not in model_provider_rules:
|
||||
raise ValueError(f'provider name {provider_name} is invalid')
|
||||
|
||||
model_provider_rule = model_provider_rules[provider_name]
|
||||
|
||||
# check provider name can be paid
|
||||
self._check_provider_payable(provider_name, model_provider_rule)
|
||||
|
||||
# get stripe checkout product id
|
||||
paid_provider = self._get_paid_provider(tenant_id, provider_name)
|
||||
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
|
||||
model_provider = model_provider_class(provider=paid_provider)
|
||||
payment_info = model_provider.get_payment_info()
|
||||
if not payment_info:
|
||||
raise ValueError(f'provider name {provider_name} not support payment')
|
||||
|
||||
payment_product_id = payment_info['product_id']
|
||||
|
||||
# create provider order
|
||||
provider_order = ProviderOrder(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
account_id=account.id,
|
||||
payment_product_id=payment_product_id,
|
||||
quantity=1,
|
||||
payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value
|
||||
)
|
||||
|
||||
db.session.add(provider_order)
|
||||
db.session.flush()
|
||||
|
||||
try:
|
||||
# create stripe checkout session
|
||||
checkout_session = stripe.checkout.Session.create(
|
||||
line_items=[
|
||||
{
|
||||
'price': f'{payment_product_id}',
|
||||
'quantity': 1,
|
||||
},
|
||||
],
|
||||
mode='payment',
|
||||
success_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=succeeded',
|
||||
cancel_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=cancelled',
|
||||
automatic_tax={'enabled': True},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise ValueError(f'provider name {provider_name} create checkout session failed, please try again later')
|
||||
|
||||
provider_order.payment_id = checkout_session.id
|
||||
db.session.commit()
|
||||
|
||||
return ProviderCheckout(checkout_session)
|
||||
|
||||
def fulfill_provider_order(self, event):
|
||||
provider_order = db.session.query(ProviderOrder) \
|
||||
.filter(ProviderOrder.payment_id == event['data']['object']['id']) \
|
||||
.first()
|
||||
|
||||
if not provider_order:
|
||||
raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}')
|
||||
|
||||
if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value:
|
||||
raise ValueError(f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')
|
||||
|
||||
provider_order.transaction_id = event['data']['object']['payment_intent']
|
||||
provider_order.currency = event['data']['object']['currency']
|
||||
provider_order.total_amount = event['data']['object']['amount_subtotal']
|
||||
provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value
|
||||
provider_order.paid_at = datetime.datetime.utcnow()
|
||||
provider_order.updated_at = provider_order.paid_at
|
||||
|
||||
# update provider quota
|
||||
provider = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == provider_order.tenant_id,
|
||||
Provider.provider_name == provider_order.provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.PAID.value
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, '
|
||||
f'provider name: {provider_order.provider_name}')
|
||||
|
||||
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name)
|
||||
model_provider = model_provider_class(provider=provider)
|
||||
payment_info = model_provider.get_payment_info()
|
||||
|
||||
if not payment_info:
|
||||
increase_quota = 0
|
||||
else:
|
||||
increase_quota = int(payment_info['increase_quota'])
|
||||
|
||||
if increase_quota > 0:
|
||||
provider.quota_limit += increase_quota
|
||||
provider.is_valid = True
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def _check_provider_payable(self, provider_name: str, model_provider_rule: dict):
|
||||
if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']:
|
||||
raise ValueError(f'provider name {provider_name} not support payment')
|
||||
|
||||
if 'system_config' not in model_provider_rule:
|
||||
raise ValueError(f'provider name {provider_name} not support payment')
|
||||
|
||||
if 'supported_quota_types' not in model_provider_rule['system_config']:
|
||||
raise ValueError(f'provider name {provider_name} not support payment')
|
||||
|
||||
if 'paid' not in model_provider_rule['system_config']['supported_quota_types']:
|
||||
raise ValueError(f'provider name {provider_name} not support payment')
|
||||
|
||||
def _get_paid_provider(self, tenant_id: str, provider_name: str):
|
||||
paid_provider = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.PAID.value,
|
||||
).first()
|
||||
|
||||
if not paid_provider:
|
||||
paid_provider = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.PAID.value,
|
||||
quota_limit=0,
|
||||
quota_used=0,
|
||||
)
|
||||
db.session.add(paid_provider)
|
||||
db.session.commit()
|
||||
|
||||
return paid_provider
|
||||
@@ -1,88 +1,503 @@
|
||||
from typing import Union
|
||||
import datetime
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from models.account import Tenant
|
||||
from models.provider import *
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
|
||||
from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
|
||||
TenantDefaultModel
|
||||
|
||||
|
||||
class ProviderService:
|
||||
|
||||
@staticmethod
|
||||
def init_supported_provider(tenant):
|
||||
"""Initialize the model provider, check whether the supported provider has a record"""
|
||||
def get_provider_list(self, tenant_id: str):
|
||||
"""
|
||||
get provider list of tenant.
|
||||
|
||||
need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]
|
||||
:param tenant_id:
|
||||
:return:
|
||||
"""
|
||||
# get rules for all providers
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rules()
|
||||
model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()]
|
||||
configurable_model_provider_names = [
|
||||
model_provider_name
|
||||
for model_provider_name, model_provider_rules in model_provider_rules.items()
|
||||
if 'custom' in model_provider_rules['support_provider_types']
|
||||
and model_provider_rules['model_flexibility'] == 'configurable'
|
||||
]
|
||||
|
||||
providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(need_init_provider_names)
|
||||
# get all providers for the tenant
|
||||
providers = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name.in_(model_provider_names),
|
||||
Provider.is_valid == True
|
||||
).order_by(Provider.created_at.desc()).all()
|
||||
|
||||
provider_name_to_provider_dict = defaultdict(list)
|
||||
for provider in providers:
|
||||
provider_name_to_provider_dict[provider.provider_name].append(provider)
|
||||
|
||||
# get all configurable provider models for the tenant
|
||||
provider_models = db.session.query(ProviderModel) \
|
||||
.filter(
|
||||
ProviderModel.tenant_id == tenant_id,
|
||||
ProviderModel.provider_name.in_(configurable_model_provider_names),
|
||||
ProviderModel.is_valid == True
|
||||
).order_by(ProviderModel.created_at.desc()).all()
|
||||
|
||||
provider_name_to_provider_model_dict = defaultdict(list)
|
||||
for provider_model in provider_models:
|
||||
provider_name_to_provider_model_dict[provider_model.provider_name].append(provider_model)
|
||||
|
||||
# get all preferred provider type for the tenant
|
||||
preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.in_(model_provider_names)
|
||||
).all()
|
||||
|
||||
exists_provider_names = []
|
||||
for provider in providers:
|
||||
exists_provider_names.append(provider.provider_name)
|
||||
provider_name_to_preferred_provider_type_dict = {preferred_provider_type.provider_name: preferred_provider_type
|
||||
for preferred_provider_type in preferred_provider_types}
|
||||
|
||||
not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))
|
||||
providers_list = {}
|
||||
|
||||
if not_exists_provider_names:
|
||||
# Initialize the model provider, check whether the supported provider has a record
|
||||
for provider_name in not_exists_provider_names:
|
||||
provider = Provider(
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=False
|
||||
)
|
||||
db.session.add(provider)
|
||||
for model_provider_name, model_provider_rule in model_provider_rules.items():
|
||||
# get preferred provider type
|
||||
preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
|
||||
preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
|
||||
tenant_id,
|
||||
model_provider_name,
|
||||
preferred_model_provider
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
provider_config_dict = {
|
||||
"preferred_provider_type": preferred_provider_type,
|
||||
"model_flexibility": model_provider_rule['model_flexibility'],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
|
||||
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
|
||||
return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
|
||||
provider_parameter_dict = {}
|
||||
if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types']:
|
||||
for quota_type_enum in ProviderQuotaType:
|
||||
quota_type = quota_type_enum.value
|
||||
if quota_type in model_provider_rule['system_config']['supported_quota_types']:
|
||||
key = ProviderType.SYSTEM.value + ':' + quota_type
|
||||
provider_parameter_dict[key] = {
|
||||
"provider_name": model_provider_name,
|
||||
"provider_type": ProviderType.SYSTEM.value,
|
||||
"config": None,
|
||||
"is_valid": False, # need update
|
||||
"quota_type": quota_type,
|
||||
"quota_unit": model_provider_rule['system_config']['quota_unit'], # need update
|
||||
"quota_limit": 0 if quota_type != ProviderQuotaType.TRIAL.value else
|
||||
model_provider_rule['system_config']['quota_limit'], # need update
|
||||
"quota_used": 0, # need update
|
||||
"last_used": None # need update
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_token_type(tenant, provider_name: ProviderName):
|
||||
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
|
||||
return llm_provider_service.get_token_type()
|
||||
if ProviderType.CUSTOM.value in model_provider_rule['support_provider_types']:
|
||||
provider_parameter_dict[ProviderType.CUSTOM.value] = {
|
||||
"provider_name": model_provider_name,
|
||||
"provider_type": ProviderType.CUSTOM.value,
|
||||
"config": None, # need update
|
||||
"models": [], # need update
|
||||
"is_valid": False,
|
||||
"last_used": None # need update
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def validate_provider_configs(tenant, provider_name: ProviderName, configs: Union[dict | str]):
|
||||
if current_app.config['DISABLE_PROVIDER_CONFIG_VALIDATION']:
|
||||
return
|
||||
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
|
||||
return llm_provider_service.config_validate(configs)
|
||||
model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
|
||||
|
||||
@staticmethod
|
||||
def get_encrypted_token(tenant, provider_name: ProviderName, configs: Union[dict | str]):
|
||||
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
|
||||
return llm_provider_service.get_encrypted_token(configs)
|
||||
current_providers = provider_name_to_provider_dict[model_provider_name]
|
||||
for provider in current_providers:
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_type = provider.quota_type
|
||||
key = f'{ProviderType.SYSTEM.value}:{quota_type}'
|
||||
|
||||
@staticmethod
|
||||
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
|
||||
is_valid: bool = True):
|
||||
if current_app.config['EDITION'] != 'CLOUD':
|
||||
return
|
||||
if key in provider_parameter_dict:
|
||||
provider_parameter_dict[key]['is_valid'] = provider.is_valid
|
||||
provider_parameter_dict[key]['quota_used'] = provider.quota_used
|
||||
provider_parameter_dict[key]['quota_limit'] = provider.quota_limit
|
||||
provider_parameter_dict[key]['last_used'] = provider.last_used
|
||||
elif provider.provider_type == ProviderType.CUSTOM.value \
|
||||
and ProviderType.CUSTOM.value in provider_parameter_dict:
|
||||
# if custom
|
||||
key = ProviderType.CUSTOM.value
|
||||
provider_parameter_dict[key]['last_used'] = provider.last_used
|
||||
provider_parameter_dict[key]['is_valid'] = provider.is_valid
|
||||
|
||||
provider = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
if model_provider_rule['model_flexibility'] == 'fixed':
|
||||
provider_parameter_dict[key]['config'] = model_provider_class(provider=provider) \
|
||||
.get_provider_credentials(obfuscated=True)
|
||||
else:
|
||||
models = []
|
||||
provider_models = provider_name_to_provider_model_dict[model_provider_name]
|
||||
for provider_model in provider_models:
|
||||
models.append({
|
||||
"model_name": provider_model.model_name,
|
||||
"model_type": provider_model.model_type,
|
||||
"config": model_provider_class(provider=provider) \
|
||||
.get_model_credentials(provider_model.model_name,
|
||||
ModelType.value_of(provider_model.model_type),
|
||||
obfuscated=True),
|
||||
"is_valid": provider_model.is_valid
|
||||
})
|
||||
provider_parameter_dict[key]['models'] = models
|
||||
|
||||
provider_config_dict['providers'] = list(provider_parameter_dict.values())
|
||||
providers_list[model_provider_name] = provider_config_dict
|
||||
|
||||
return providers_list
|
||||
|
||||
def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
|
||||
"""
|
||||
validate custom provider config.
|
||||
|
||||
:param provider_name:
|
||||
:param config:
|
||||
:return:
|
||||
:raises CredentialsValidateFailedError: When the config credential verification fails.
|
||||
"""
|
||||
# get model provider rules
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
|
||||
|
||||
if model_provider_rules['model_flexibility'] != 'fixed':
|
||||
raise ValueError('Only support fixed model provider')
|
||||
|
||||
# only support provider type CUSTOM
|
||||
if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']:
|
||||
raise ValueError('Only support provider type CUSTOM')
|
||||
|
||||
# validate provider config
|
||||
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
|
||||
model_provider_class.is_provider_credentials_valid_or_raise(config)
|
||||
|
||||
def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
|
||||
"""
|
||||
save custom provider config.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider_name:
|
||||
:param config:
|
||||
:return:
|
||||
"""
|
||||
# validate custom provider config
|
||||
self.custom_provider_config_validate(provider_name, config)
|
||||
|
||||
# get provider
|
||||
provider = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value
|
||||
).one_or_none()
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
|
||||
encrypted_config = model_provider_class.encrypt_provider_credentials(tenant_id, config)
|
||||
|
||||
# save provider
|
||||
if provider:
|
||||
provider.encrypted_config = json.dumps(encrypted_config)
|
||||
provider.is_valid = True
|
||||
provider.updated_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
else:
|
||||
provider = Provider(
|
||||
tenant_id=tenant.id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=quota_limit,
|
||||
encrypted_config='',
|
||||
is_valid=is_valid,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_config),
|
||||
is_valid=True
|
||||
)
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
|
||||
"""
|
||||
delete custom provider.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider_name:
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider_name,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
if provider:
|
||||
try:
|
||||
self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
def custom_provider_model_config_validate(self,
|
||||
provider_name: str,
|
||||
model_name: str,
|
||||
model_type: str,
|
||||
config: dict) -> None:
|
||||
"""
|
||||
validate custom provider model config.
|
||||
|
||||
:param provider_name:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param config:
|
||||
:return:
|
||||
:raises CredentialsValidateFailedError: When the config credential verification fails.
|
||||
"""
|
||||
# get model provider rules
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
|
||||
|
||||
if model_provider_rules['model_flexibility'] != 'configurable':
|
||||
raise ValueError('Only support configurable model provider')
|
||||
|
||||
# only support provider type CUSTOM
|
||||
if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']:
|
||||
raise ValueError('Only support provider type CUSTOM')
|
||||
|
||||
# validate provider model config
|
||||
model_type = ModelType.value_of(model_type)
|
||||
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
|
||||
model_provider_class.is_model_credentials_valid_or_raise(model_name, model_type, config)
|
||||
|
||||
def add_or_save_custom_provider_model_config(self,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
model_name: str,
|
||||
model_type: str,
|
||||
config: dict) -> None:
|
||||
"""
|
||||
Add or save custom provider model config.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider_name:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param config:
|
||||
:return:
|
||||
"""
|
||||
# validate custom provider model config
|
||||
self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)
|
||||
|
||||
# get provider
|
||||
provider = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider_name,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
provider = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=True
|
||||
)
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
elif not provider.is_valid:
|
||||
provider.is_valid = True
|
||||
provider.encrypted_config = None
|
||||
db.session.commit()
|
||||
|
||||
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
|
||||
encrypted_config = model_provider_class.encrypt_model_credentials(
|
||||
tenant_id,
|
||||
model_name,
|
||||
ModelType.value_of(model_type),
|
||||
config
|
||||
)
|
||||
|
||||
# get provider model
|
||||
provider_model = db.session.query(ProviderModel) \
|
||||
.filter(
|
||||
ProviderModel.tenant_id == tenant_id,
|
||||
ProviderModel.provider_name == provider_name,
|
||||
ProviderModel.model_name == model_name,
|
||||
ProviderModel.model_type == model_type
|
||||
).first()
|
||||
|
||||
if provider_model:
|
||||
provider_model.encrypted_config = json.dumps(encrypted_config)
|
||||
provider_model.is_valid = True
|
||||
db.session.commit()
|
||||
else:
|
||||
provider_model = ProviderModel(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
encrypted_config=json.dumps(encrypted_config),
|
||||
is_valid=True
|
||||
)
|
||||
db.session.add(provider_model)
|
||||
db.session.commit()
|
||||
|
||||
def delete_custom_provider_model(self,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
model_name: str,
|
||||
model_type: str) -> None:
|
||||
"""
|
||||
delete custom provider model.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider_name:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model = db.session.query(ProviderModel) \
|
||||
.filter(
|
||||
ProviderModel.tenant_id == tenant_id,
|
||||
ProviderModel.provider_name == provider_name,
|
||||
ProviderModel.model_name == model_name,
|
||||
ProviderModel.model_type == model_type
|
||||
).first()
|
||||
|
||||
if provider_model:
|
||||
db.session.delete(provider_model)
|
||||
db.session.commit()
|
||||
|
||||
def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
|
||||
"""
|
||||
switch preferred provider.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider_name:
|
||||
:param preferred_provider_type:
|
||||
:return:
|
||||
"""
|
||||
provider_type = ProviderType.value_of(preferred_provider_type)
|
||||
if not provider_type:
|
||||
raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')
|
||||
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
|
||||
if preferred_provider_type not in model_provider_rules['support_provider_types']:
|
||||
raise ValueError(f'Not support provider type: {preferred_provider_type}')
|
||||
|
||||
model_provider = ModelProviderFactory.get_model_provider_class(provider_name)
|
||||
if not model_provider.is_provider_type_system_supported():
|
||||
return
|
||||
|
||||
# get preferred provider
|
||||
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name == provider_name
|
||||
).first()
|
||||
|
||||
if preferred_model_provider:
|
||||
preferred_model_provider.preferred_provider_type = preferred_provider_type
|
||||
else:
|
||||
preferred_model_provider = TenantPreferredModelProvider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
preferred_provider_type=preferred_provider_type
|
||||
)
|
||||
db.session.add(preferred_model_provider)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]:
|
||||
"""
|
||||
get default model of model type.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelFactory.get_default_model(tenant_id, ModelType.value_of(model_type))
|
||||
|
||||
def update_default_model_of_model_type(self,
|
||||
tenant_id: str,
|
||||
model_type: str,
|
||||
provider_name: str,
|
||||
model_name: str) -> TenantDefaultModel:
|
||||
"""
|
||||
update default model of model type.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_type:
|
||||
:param provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
return ModelFactory.update_default_model(tenant_id, ModelType.value_of(model_type), provider_name, model_name)
|
||||
|
||||
def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
|
||||
"""
|
||||
get valid model list.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
valid_model_list = []
|
||||
|
||||
# get model provider rules
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rules()
|
||||
for model_provider_name, model_provider_rule in model_provider_rules.items():
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
if not model_provider:
|
||||
continue
|
||||
|
||||
model_list = model_provider.get_supported_model_list(ModelType.value_of(model_type))
|
||||
provider = model_provider.provider
|
||||
for model in model_list:
|
||||
valid_model_dict = {
|
||||
"model_name": model['id'],
|
||||
"model_type": model_type,
|
||||
"model_provider": {
|
||||
"provider_name": provider.provider_name,
|
||||
"provider_type": provider.provider_type
|
||||
},
|
||||
'features': []
|
||||
}
|
||||
|
||||
if 'features' in model:
|
||||
valid_model_dict['features'] = model['features']
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
valid_model_dict['model_provider']['quota_type'] = provider.quota_type
|
||||
valid_model_dict['model_provider']['quota_unit'] = model_provider_rule['system_config']['quota_unit']
|
||||
valid_model_dict['model_provider']['quota_limit'] = provider.quota_limit
|
||||
valid_model_dict['model_provider']['quota_used'] = provider.quota_used
|
||||
|
||||
valid_model_list.append(valid_model_dict)
|
||||
|
||||
return valid_model_list
|
||||
|
||||
def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
|
||||
-> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
It depends on preferred provider in use.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
if not model_provider:
|
||||
# get empty model provider
|
||||
return ModelKwargsRules()
|
||||
|
||||
# get model parameter rules
|
||||
return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
from models.provider import Provider, ProviderType, ProviderName
|
||||
from models.provider import Provider
|
||||
|
||||
|
||||
class WorkspaceService:
|
||||
@@ -13,8 +13,8 @@ class WorkspaceService:
|
||||
'status': tenant.status,
|
||||
'created_at': tenant.created_at,
|
||||
'providers': [],
|
||||
'in_trail': False,
|
||||
'trial_end_reason': 'using_custom'
|
||||
'in_trial': True,
|
||||
'trial_end_reason': None
|
||||
}
|
||||
|
||||
# Get providers
|
||||
@@ -25,25 +25,4 @@ class WorkspaceService:
|
||||
# Add providers to the tenant info
|
||||
tenant_info['providers'] = providers
|
||||
|
||||
custom_provider = None
|
||||
system_provider = None
|
||||
|
||||
for provider in providers:
|
||||
if provider.provider_type == ProviderType.CUSTOM.value:
|
||||
if provider.is_valid and provider.encrypted_config:
|
||||
custom_provider = provider
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value:
|
||||
if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
|
||||
system_provider = provider
|
||||
|
||||
if system_provider and not custom_provider:
|
||||
quota_used = system_provider.quota_used if system_provider.quota_used is not None else 0
|
||||
quota_limit = system_provider.quota_limit if system_provider.quota_limit is not None else 0
|
||||
|
||||
if quota_used >= quota_limit:
|
||||
tenant_info['trial_end_reason'] = 'trial_exceeded'
|
||||
else:
|
||||
tenant_info['in_trail'] = True
|
||||
tenant_info['trial_end_reason'] = None
|
||||
|
||||
return tenant_info
|
||||
|
||||
Reference in New Issue
Block a user