mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 19:21:05 +08:00
refactor: model_load_balancing_service and api_tools_manage_service (#34434)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -110,20 +110,21 @@ class ModelLoadBalancingService:
|
||||
credential_source_type = CredentialSourceType.CUSTOM_MODEL
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
or_(
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||
LoadBalancingModelConfig.credential_source_type.is_(None),
|
||||
),
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
.all()
|
||||
load_balancing_configs = list(
|
||||
db.session.scalars(
|
||||
select(LoadBalancingModelConfig)
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
or_(
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||
LoadBalancingModelConfig.credential_source_type.is_(None),
|
||||
),
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
).all()
|
||||
)
|
||||
|
||||
if provider_configuration.custom_configuration.provider:
|
||||
@@ -143,7 +144,7 @@ class ModelLoadBalancingService:
|
||||
load_balancing_configs.insert(0, inherit_config)
|
||||
else:
|
||||
# move the inherit configuration to the first
|
||||
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
|
||||
for i, load_balancing_config in enumerate(load_balancing_configs.copy()):
|
||||
if load_balancing_config.name == "__inherit__":
|
||||
inherit_config = load_balancing_configs.pop(i)
|
||||
load_balancing_configs.insert(0, inherit_config)
|
||||
@@ -235,8 +236,8 @@ class ModelLoadBalancingService:
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_model_config = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
load_balancing_model_config = db.session.scalar(
|
||||
select(LoadBalancingModelConfig)
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
@@ -244,7 +245,7 @@ class ModelLoadBalancingService:
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not load_balancing_model_config:
|
||||
@@ -351,26 +352,26 @@ class ModelLoadBalancingService:
|
||||
|
||||
if credential_id:
|
||||
if config_from == "predefined-model":
|
||||
credential_record = (
|
||||
db.session.query(ProviderCredential)
|
||||
.filter_by(
|
||||
id=credential_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
credential_record = db.session.scalar(
|
||||
select(ProviderCredential)
|
||||
.where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == tenant_id,
|
||||
ProviderCredential.provider_name == provider_configuration.provider.provider,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
credential_record = (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.filter_by(
|
||||
id=credential_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type_enum,
|
||||
credential_record = db.session.scalar(
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name == provider_configuration.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type_enum,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not credential_record:
|
||||
raise ValueError(f"Provider credential with id {credential_id} not found")
|
||||
@@ -510,8 +511,8 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = None
|
||||
if config_id:
|
||||
# Get load balancing config
|
||||
load_balancing_model_config = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
load_balancing_model_config = db.session.scalar(
|
||||
select(LoadBalancingModelConfig)
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
@@ -519,7 +520,7 @@ class ModelLoadBalancingService:
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not load_balancing_model_config:
|
||||
|
||||
@@ -124,13 +124,13 @@ class ApiToolManageService:
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is not None:
|
||||
@@ -215,13 +215,13 @@ class ApiToolManageService:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@@ -259,13 +259,13 @@ class ApiToolManageService:
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@@ -328,13 +328,13 @@ class ApiToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@@ -378,13 +378,13 @@ class ApiToolManageService:
|
||||
if tool_bundle is None:
|
||||
raise ValueError(f"invalid tool name {tool_name}")
|
||||
|
||||
db_provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
db_provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
|
||||
Reference in New Issue
Block a user