refactor: model_load_balancing_service and api_tools_manage_service (#34434)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-04-02 06:38:35 +02:00
committed by GitHub
parent f9d9ad7a38
commit 399d3f8da5
3 changed files with 61 additions and 60 deletions

View File

@@ -110,20 +110,21 @@ class ModelLoadBalancingService:
credential_source_type = CredentialSourceType.CUSTOM_MODEL
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
load_balancing_configs = list(
db.session.scalars(
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
).all()
)
if provider_configuration.custom_configuration.provider:
@@ -143,7 +144,7 @@ class ModelLoadBalancingService:
load_balancing_configs.insert(0, inherit_config)
else:
# move the inherit configuration to the first
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
for i, load_balancing_config in enumerate(load_balancing_configs.copy()):
if load_balancing_config.name == "__inherit__":
inherit_config = load_balancing_configs.pop(i)
load_balancing_configs.insert(0, inherit_config)
@@ -235,8 +236,8 @@ class ModelLoadBalancingService:
model_type_enum = ModelType.value_of(model_type)
# Get load balancing configurations
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
load_balancing_model_config = db.session.scalar(
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
@@ -244,7 +245,7 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
.limit(1)
)
if not load_balancing_model_config:
@@ -351,26 +352,26 @@ class ModelLoadBalancingService:
if credential_id:
if config_from == "predefined-model":
credential_record = (
db.session.query(ProviderCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
credential_record = db.session.scalar(
select(ProviderCredential)
.where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name == provider_configuration.provider.provider,
)
.first()
.limit(1)
)
else:
credential_record = (
db.session.query(ProviderModelCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_name=model,
model_type=model_type_enum,
credential_record = db.session.scalar(
select(ProviderModelCredential)
.where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_configuration.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type_enum,
)
.first()
.limit(1)
)
if not credential_record:
raise ValueError(f"Provider credential with id {credential_id} not found")
@@ -510,8 +511,8 @@ class ModelLoadBalancingService:
load_balancing_model_config = None
if config_id:
# Get load balancing config
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
load_balancing_model_config = db.session.scalar(
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
@@ -519,7 +520,7 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
.limit(1)
)
if not load_balancing_model_config:

View File

@@ -124,13 +124,13 @@ class ApiToolManageService:
provider_name = provider_name.strip()
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
provider = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
.limit(1)
)
if provider is not None:
@@ -215,13 +215,13 @@ class ApiToolManageService:
"""
list api tool provider tools
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
provider: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
.limit(1)
)
if provider is None:
@@ -259,13 +259,13 @@ class ApiToolManageService:
provider_name = provider_name.strip()
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
provider = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider,
)
.first()
.limit(1)
)
if provider is None:
@@ -328,13 +328,13 @@ class ApiToolManageService:
"""
delete tool provider
"""
provider = (
db.session.query(ApiToolProvider)
provider = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
.limit(1)
)
if provider is None:
@@ -378,13 +378,13 @@ class ApiToolManageService:
if tool_bundle is None:
raise ValueError(f"invalid tool name {tool_name}")
db_provider = (
db.session.query(ApiToolProvider)
db_provider = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
.limit(1)
)
if not db_provider: