diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 3861fb8e994..bfcc9a7f0a2 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource): parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource): parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 35fc61e48ad..f174fcc5d39 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource): model_load_balancing_service = ModelLoadBalancingService() is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + tenant_id=tenant_id, + provider=provider, + model=args["model"], + model_type=args["model_type"], + config_from=args.get("config_from", ""), ) if args.get("config_from", "") == "predefined-model": @@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource): choices=[mt.value for mt in ModelType], location="json", ) - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource): ) parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index ca3c36b8783..b74e081dd46 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,5 +1,6 @@ import json import logging +import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError @@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel): with Session(db.engine) as new_session: return _validate(new_session) - def create_provider_credential(self, credentials: dict, credential_name: str) -> None: + def _generate_provider_credential_name(self, session) -> str: + """ + Generate a unique credential name for provider. + :return: credential name + """ + return self._generate_next_api_key_name( + session=session, + query_factory=lambda: select(ProviderCredential).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ), + ) + + def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str: + """ + Generate a unique credential name for custom model. + :return: credential name + """ + return self._generate_next_api_key_name( + session=session, + query_factory=lambda: select(ProviderModelCredential).where( + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ), + ) + + def _generate_next_api_key_name(self, session, query_factory) -> str: + """ + Generate next available API KEY name by finding the highest numbered suffix. + :param session: database session + :param query_factory: function that returns the SQLAlchemy query + :return: next available API KEY name + """ + try: + stmt = query_factory() + credential_records = session.execute(stmt).scalars().all() + + if not credential_records: + return "API KEY 1" + + # Extract numbers from API KEY pattern using list comprehension + pattern = re.compile(r"^API KEY\s+(\d+)$") + numbers = [ + int(match.group(1)) + for cr in credential_records + if cr.credential_name and (match := pattern.match(cr.credential_name.strip())) + ] + + # Return next sequential number + next_number = max(numbers, default=0) + 1 + return f"API KEY {next_number}" + + except Exception as e: + logger.warning("Error generating next credential name: %s", str(e)) + return "API KEY 1" + + def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: """ Add custom provider credentials. :param credentials: provider credentials @@ -351,8 +410,12 @@ class ProviderConfiguration(BaseModel): :return: """ with Session(db.engine) as session: - if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + if credential_name and self._check_provider_credential_name_exists( + credential_name=credential_name, session=session + ): raise ValueError(f"Credential with name '{credential_name}' already exists.") + else: + credential_name = self._generate_provider_credential_name(session) credentials = self.validate_provider_credentials(credentials=credentials, session=session) provider_record = self._get_provider_record(session) @@ -395,7 +458,7 @@ class ProviderConfiguration(BaseModel): self, credentials: dict, credential_id: str, - credential_name: str, + credential_name: str | None, ) -> None: """ update a saved provider credential (by credential_id). @@ -406,7 +469,7 @@ class ProviderConfiguration(BaseModel): :return: """ with Session(db.engine) as session: - if self._check_provider_credential_name_exists( + if credential_name and self._check_provider_credential_name_exists( credential_name=credential_name, session=session, exclude_id=credential_id ): raise ValueError(f"Credential with name '{credential_name}' already exists.") @@ -428,9 +491,9 @@ class ProviderConfiguration(BaseModel): try: # Update credential credential_record.encrypted_config = json.dumps(credentials) - credential_record.credential_name = credential_name credential_record.updated_at = naive_utc_now() - + if credential_name: + credential_record.credential_name = credential_name session.commit() if provider_record and provider_record.credential_id == credential_id: @@ -532,13 +595,7 @@ class ProviderConfiguration(BaseModel): cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) lb_credentials_cache.delete() - - lb_config.credential_id = None - lb_config.encrypted_config = None - lb_config.enabled = False - lb_config.name = "__delete__" - lb_config.updated_at = naive_utc_now() - session.add(lb_config) + session.delete(lb_config) # Check if this is the currently active credential provider_record = self._get_provider_record(session) @@ -822,7 +879,7 @@ class ProviderConfiguration(BaseModel): return _validate(new_session) def create_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str + self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None ) -> None: """ Create a custom model credential. @@ -833,10 +890,14 @@ class ProviderConfiguration(BaseModel): :return: """ with Session(db.engine) as session: - if self._check_custom_model_credential_name_exists( + if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, session=session ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") + else: + credential_name = self._generate_custom_model_credential_name( + model=model, model_type=model_type, session=session + ) # validate custom model config credentials = self.validate_custom_model_credentials( model_type=model_type, model=model, credentials=credentials, session=session @@ -880,7 +941,7 @@ class ProviderConfiguration(BaseModel): raise def update_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str + self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str ) -> None: """ Update a custom model credential. @@ -893,7 +954,7 @@ class ProviderConfiguration(BaseModel): :return: """ with Session(db.engine) as session: - if self._check_custom_model_credential_name_exists( + if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, @@ -925,8 +986,9 @@ class ProviderConfiguration(BaseModel): try: # Update credential credential_record.encrypted_config = json.dumps(credentials) - credential_record.credential_name = credential_name credential_record.updated_at = naive_utc_now() + if credential_name: + credential_record.credential_name = credential_name session.commit() if provider_model_record and provider_model_record.credential_id == credential_id: @@ -982,12 +1044,7 @@ class ProviderConfiguration(BaseModel): cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) lb_credentials_cache.delete() - lb_config.credential_id = None - lb_config.encrypted_config = None - lb_config.enabled = False - lb_config.name = "__delete__" - lb_config.updated_at = naive_utc_now() - session.add(lb_config) + session.delete(lb_config) # Check if this is the currently active credential provider_model_record = self._get_custom_model_record(model_type, model, session=session) @@ -1054,6 +1111,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_name=model, model_type=model_type.to_origin_model_type(), + is_valid=True, credential_id=credential_id, ) else: @@ -1605,11 +1663,9 @@ class ProviderConfiguration(BaseModel): if config.credential_source_type != "custom_model" ] - if len(provider_model_lb_configs) > 1: - load_balancing_enabled = True - - if any(config.name == "__delete__" for config in provider_model_lb_configs): - has_invalid_load_balancing_configs = True + load_balancing_enabled = model_setting.load_balancing_enabled + # when the user enable load_balancing but available configs are less than 2 display warning + has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2 provider_models.append( ModelWithProviderEntity( @@ -1631,6 +1687,8 @@ class ProviderConfiguration(BaseModel): for model_configuration in self.custom_configuration.models: if model_configuration.model_type not in model_types: continue + if model_configuration.unadded_to_model_list: + continue if model and model != model_configuration.model: continue try: @@ -1663,11 +1721,9 @@ class ProviderConfiguration(BaseModel): if config.credential_source_type != "provider" ] - if len(custom_model_lb_configs) > 1: - load_balancing_enabled = True - - if any(config.name == "__delete__" for config in custom_model_lb_configs): - has_invalid_load_balancing_configs = True + load_balancing_enabled = model_setting.load_balancing_enabled + # when the user enable load_balancing but available configs are less than 2 display warning + has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2 if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: status = ModelStatus.CREDENTIAL_REMOVED diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 1b87bffe574..79a7514bbc7 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel): current_credential_id: Optional[str] = None current_credential_name: Optional[str] = None available_model_credentials: list[CredentialConfiguration] = [] + unadded_to_model_list: Optional[bool] = False # pydantic configs model_config = ConfigDict(protected_namespaces=()) +class UnaddedModelConfiguration(BaseModel): + """ + Model class for provider unadded model configuration. + """ + + model: str + model_type: ModelType + + class CustomConfiguration(BaseModel): """ Model class for provider custom configuration. @@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel): provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] + can_added_models: list[UnaddedModelConfiguration] = [] class ModelLoadBalancingConfiguration(BaseModel): @@ -144,6 +155,7 @@ class ModelSettings(BaseModel): model: str model_type: ModelType enabled: bool = True + load_balancing_enabled: bool = False load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] # pydantic configs diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 04996442ca7..f8ef0c18465 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,8 +1,9 @@ import contextlib import json from collections import defaultdict +from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, Optional +from typing import Any, Optional, cast from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -22,6 +23,7 @@ from core.entities.provider_entities import ( QuotaConfiguration, QuotaUnit, SystemConfiguration, + UnaddedModelConfiguration, ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType @@ -537,6 +539,23 @@ class ProviderManager: for credential in available_credentials ] + @staticmethod + def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]: + """ + Get all the credentials records from ProviderModelCredential by provider_name + + :param tenant_id: workspace id + :param provider_name: provider name + + """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name + ) + + all_credentials = session.scalars(stmt).all() + return all_credentials + @staticmethod def _init_trial_provider_records( tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] @@ -623,6 +642,44 @@ class ProviderManager: :param provider_model_records: provider model records :return: """ + # Get custom provider configuration + custom_provider_configuration = self._get_custom_provider_configuration( + tenant_id, provider_entity, provider_records + ) + + # Get all model credentials once + all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider) + + # Get custom models which have not been added to the model list yet + unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials) + + # Get custom model configurations + custom_model_configurations = self._get_custom_model_configurations( + tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials + ) + + can_added_models = [ + UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models + ] + + return CustomConfiguration( + provider=custom_provider_configuration, + models=custom_model_configurations, + can_added_models=can_added_models, + ) + + def _get_custom_provider_configuration( + self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] + ) -> CustomProviderConfiguration | None: + """Get custom provider configuration.""" + # Find custom provider record (non-system) + custom_provider_record = next( + (record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None + ) + + if not custom_provider_record: + return None + # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas @@ -630,113 +687,98 @@ class ProviderManager: else [] ) - # Get custom provider record - custom_provider_record = None - for provider_record in provider_records: - if provider_record.provider_type == ProviderType.SYSTEM.value: - continue + # Get and decrypt provider credentials + provider_credentials = self._get_and_decrypt_credentials( + tenant_id=tenant_id, + record_id=custom_provider_record.id, + encrypted_config=custom_provider_record.encrypted_config, + secret_variables=provider_credential_secret_variables, + cache_type=ProviderCredentialsCacheType.PROVIDER, + is_provider=True, + ) - custom_provider_record = provider_record + return CustomProviderConfiguration( + credentials=provider_credentials, + current_credential_name=custom_provider_record.credential_name, + current_credential_id=custom_provider_record.credential_id, + available_credentials=self.get_provider_available_credentials( + tenant_id, custom_provider_record.provider_name + ), + ) - # Get custom provider credentials - custom_provider_configuration = None - if custom_provider_record: - provider_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=custom_provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER, - ) + def _get_can_added_models( + self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential] + ) -> list[dict]: + """Get the custom models and credentials from enterprise version which haven't add to the model list""" + existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records} - # Get cached provider credentials - cached_provider_credentials = provider_credentials_cache.get() + # Get not added custom models credentials + not_added_custom_models_credentials = [ + credential + for credential in all_model_credentials + if (credential.model_name, credential.model_type) not in existing_model_set + ] - if not cached_provider_credentials: - try: - # fix origin data - if custom_provider_record.encrypted_config is None: - provider_credentials = {} - elif not custom_provider_record.encrypted_config.startswith("{"): - provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} - else: - provider_credentials = json.loads(custom_provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} + # Group credentials by model + model_to_credentials = defaultdict(list) + for credential in not_added_custom_models_credentials: + model_to_credentials[(credential.model_name, credential.model_type)].append(credential) - # Get decoding rsa key and cipher for decrypting credentials - if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: - self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + return [ + { + "model": model_key[0], + "model_type": ModelType.value_of(model_key[1]), + "available_model_credentials": [ + CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name) + for cred in creds + ], + } + for model_key, creds in model_to_credentials.items() + ] - for variable in provider_credential_secret_variables: - if variable in provider_credentials: - with contextlib.suppress(ValueError): - provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable) or "", # type: ignore - self.decoding_rsa_key, - self.decoding_cipher_rsa, - ) - - # cache provider credentials - provider_credentials_cache.set(credentials=provider_credentials) - else: - provider_credentials = cached_provider_credentials - - custom_provider_configuration = CustomProviderConfiguration( - credentials=provider_credentials, - current_credential_name=custom_provider_record.credential_name, - current_credential_id=custom_provider_record.credential_id, - available_credentials=self.get_provider_available_credentials( - tenant_id, custom_provider_record.provider_name - ), - ) - - # Get provider model credential secret variables + def _get_custom_model_configurations( + self, + tenant_id: str, + provider_entity: ProviderEntity, + provider_model_records: list[ProviderModel], + can_added_models: list[dict], + all_model_credentials: Sequence[ProviderModelCredential], + ) -> list[CustomModelConfiguration]: + """Get custom model configurations.""" + # Get model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas if provider_entity.model_credential_schema else [] ) - # Get custom provider model credentials + # Create credentials lookup for efficient access + credentials_map = defaultdict(list) + for credential in all_model_credentials: + credentials_map[(credential.model_name, credential.model_type)].append(credential) + custom_model_configurations = [] + + # Process existing model records for provider_model_record in provider_model_records: - available_model_credentials = self.get_provider_model_available_credentials( - tenant_id, - provider_model_record.provider_name, - provider_model_record.model_name, - provider_model_record.model_type, + # Use pre-fetched credentials instead of individual database calls + available_model_credentials = [ + CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name) + for cred in credentials_map.get( + (provider_model_record.model_name, provider_model_record.model_type), [] + ) + ] + + # Get and decrypt model credentials + provider_model_credentials = self._get_and_decrypt_credentials( + tenant_id=tenant_id, + record_id=provider_model_record.id, + encrypted_config=provider_model_record.encrypted_config, + secret_variables=model_credential_secret_variables, + cache_type=ProviderCredentialsCacheType.MODEL, + is_provider=False, ) - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL - ) - - # Get cached provider model credentials - cached_provider_model_credentials = provider_model_credentials_cache.get() - - if not cached_provider_model_credentials and provider_model_record.encrypted_config: - try: - provider_model_credentials = json.loads(provider_model_record.encrypted_config) - except JSONDecodeError: - continue - - # Get decoding rsa key and cipher for decrypting credentials - if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: - self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - - for variable in model_credential_secret_variables: - if variable in provider_model_credentials: - with contextlib.suppress(ValueError): - provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_model_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa, - ) - - # cache provider model credentials - provider_model_credentials_cache.set(credentials=provider_model_credentials) - else: - provider_model_credentials = cached_provider_model_credentials - custom_model_configurations.append( CustomModelConfiguration( model=provider_model_record.model_name, @@ -748,7 +790,71 @@ class ProviderManager: ) ) - return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) + # Add models that can be added + for model in can_added_models: + custom_model_configurations.append( + CustomModelConfiguration( + model=model["model"], + model_type=model["model_type"], + credentials=None, + current_credential_id=None, + current_credential_name=None, + available_model_credentials=model["available_model_credentials"], + unadded_to_model_list=True, + ) + ) + + return custom_model_configurations + + def _get_and_decrypt_credentials( + self, + tenant_id: str, + record_id: str, + encrypted_config: str | None, + secret_variables: list[str], + cache_type: ProviderCredentialsCacheType, + is_provider: bool = False, + ) -> dict: + """Get and decrypt credentials with caching.""" + credentials_cache = ProviderCredentialsCache( + tenant_id=tenant_id, + identity_id=record_id, + cache_type=cache_type, + ) + + # Try to get from cache first + cached_credentials = credentials_cache.get() + if cached_credentials: + return cached_credentials + + # Parse encrypted config + if not encrypted_config: + return {} + + if is_provider and not encrypted_config.startswith("{"): + return {"openai_api_key": encrypted_config} + + try: + credentials = cast(dict, json.loads(encrypted_config)) + except JSONDecodeError: + return {} + + # Decrypt secret variables + if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: + self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + + for variable in secret_variables: + if variable in credentials: + with contextlib.suppress(ValueError): + credentials[variable] = encrypter.decrypt_token_with_decoding( + credentials.get(variable) or "", + self.decoding_rsa_key, + self.decoding_cipher_rsa, + ) + + # Cache the decrypted credentials + credentials_cache.set(credentials=credentials) + return credentials def _to_system_configuration( self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] @@ -956,18 +1062,6 @@ class ProviderManager: load_balancing_model_config.model_name == provider_model_setting.model_name and load_balancing_model_config.model_type == provider_model_setting.model_type ): - if load_balancing_model_config.name == "__delete__": - # to calculate current model whether has invalidate lb configs - load_balancing_configs.append( - ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials={}, - credential_source_type=load_balancing_model_config.credential_source_type, - ) - ) - continue - if not load_balancing_model_config.enabled: continue @@ -1033,6 +1127,7 @@ class ProviderManager: model=provider_model_setting.model_name, model_type=ModelType.value_of(provider_model_setting.model_type), enabled=provider_model_setting.enabled, + load_balancing_enabled=provider_model_setting.load_balancing_enabled, load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], ) ) diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 056decda269..1fe259dd468 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -13,6 +13,7 @@ from core.entities.provider_entities import ( CustomModelConfiguration, ProviderQuotaType, QuotaConfiguration, + UnaddedModelConfiguration, ) from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType @@ -45,6 +46,7 @@ class CustomConfigurationResponse(BaseModel): current_credential_name: Optional[str] = None available_credentials: Optional[list[CredentialConfiguration]] = None custom_models: Optional[list[CustomModelConfiguration]] = None + can_added_models: Optional[list[UnaddedModelConfiguration]] = None class SystemConfigurationResponse(BaseModel): diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index d830034f114..17696f5cd85 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -3,6 +3,8 @@ import logging from json import JSONDecodeError from typing import Optional, Union +from sqlalchemy import or_ + from constants import HIDDEN_VALUE from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter @@ -69,7 +71,7 @@ class ModelLoadBalancingService: provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) def get_load_balancing_configs( - self, tenant_id: str, provider: str, model: str, model_type: str + self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = "" ) -> tuple[bool, list[dict]]: """ Get load balancing configurations. @@ -100,6 +102,11 @@ class ModelLoadBalancingService: if provider_model_setting and provider_model_setting.load_balancing_enabled: is_load_balancing_enabled = True + if config_from == "predefined-model": + credential_source_type = "provider" + else: + credential_source_type = "custom_model" + # Get load balancing configurations load_balancing_configs = ( db.session.query(LoadBalancingModelConfig) @@ -108,6 +115,10 @@ class ModelLoadBalancingService: LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, + or_( + LoadBalancingModelConfig.credential_source_type == credential_source_type, + LoadBalancingModelConfig.credential_source_type.is_(None), + ), ) .order_by(LoadBalancingModelConfig.created_at) .all() @@ -405,7 +416,7 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name in {"__inherit__", "__delete__"}: + if name == "__inherit__": raise ValueError("Invalid load balancing config name") if credential_id: diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 9e9422f9f71..69c7e4cf588 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -72,6 +72,7 @@ class ModelProviderService: provider_config = provider_configuration.custom_configuration.provider model_config = provider_configuration.custom_configuration.models + can_added_models = provider_configuration.custom_configuration.can_added_models provider_response = ProviderResponse( tenant_id=tenant_id, @@ -95,6 +96,7 @@ class ModelProviderService: current_credential_name=getattr(provider_config, "current_credential_name", None), available_credentials=getattr(provider_config, "available_credentials", []), custom_models=model_config, + can_added_models=can_added_models, ), system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, @@ -152,7 +154,7 @@ class ModelProviderService: provider_configuration.validate_provider_credentials(credentials) def create_provider_credential( - self, tenant_id: str, provider: str, credentials: dict, credential_name: str + self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None ) -> None: """ Create and save new provider credentials. @@ -172,7 +174,7 @@ class ModelProviderService: provider: str, credentials: dict, credential_id: str, - credential_name: str, + credential_name: str | None, ) -> None: """ update a saved provider credential (by credential_id). @@ -249,7 +251,7 @@ class ModelProviderService: ) def create_model_credential( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None ) -> None: """ create and save model credentials. @@ -278,7 +280,7 @@ class ModelProviderService: model: str, credentials: dict, credential_id: str, - credential_name: str, + credential_name: str | None, ) -> None: """ update model credentials.