refactor: model_load_balancing_service and api_tools_manage_service (#34434)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-04-02 06:38:35 +02:00
committed by GitHub
parent f9d9ad7a38
commit 399d3f8da5
3 changed files with 61 additions and 60 deletions

View File

@@ -110,20 +110,21 @@ class ModelLoadBalancingService:
credential_source_type = CredentialSourceType.CUSTOM_MODEL
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
load_balancing_configs = list(
db.session.scalars(
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
).all()
)
if provider_configuration.custom_configuration.provider:
@@ -143,7 +144,7 @@ class ModelLoadBalancingService:
load_balancing_configs.insert(0, inherit_config)
else:
# move the inherit configuration to the first
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
for i, load_balancing_config in enumerate(load_balancing_configs.copy()):
if load_balancing_config.name == "__inherit__":
inherit_config = load_balancing_configs.pop(i)
load_balancing_configs.insert(0, inherit_config)
@@ -235,8 +236,8 @@ class ModelLoadBalancingService:
model_type_enum = ModelType.value_of(model_type)
# Get load balancing configurations
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
load_balancing_model_config = db.session.scalar(
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
@@ -244,7 +245,7 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
.limit(1)
)
if not load_balancing_model_config:
@@ -351,26 +352,26 @@ class ModelLoadBalancingService:
if credential_id:
if config_from == "predefined-model":
credential_record = (
db.session.query(ProviderCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
credential_record = db.session.scalar(
select(ProviderCredential)
.where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name == provider_configuration.provider.provider,
)
.first()
.limit(1)
)
else:
credential_record = (
db.session.query(ProviderModelCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_name=model,
model_type=model_type_enum,
credential_record = db.session.scalar(
select(ProviderModelCredential)
.where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_configuration.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type_enum,
)
.first()
.limit(1)
)
if not credential_record:
raise ValueError(f"Provider credential with id {credential_id} not found")
@@ -510,8 +511,8 @@ class ModelLoadBalancingService:
load_balancing_model_config = None
if config_id:
# Get load balancing config
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
load_balancing_model_config = db.session.scalar(
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
@@ -519,7 +520,7 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
.limit(1)
)
if not load_balancing_model_config:

View File

@@ -124,13 +124,13 @@ class ApiToolManageService:
provider_name = provider_name.strip()
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
provider = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
.limit(1)
)
if provider is not None:
@@ -215,13 +215,13 @@ class ApiToolManageService:
"""
list api tool provider tools
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
provider: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
.limit(1)
)
if provider is None:
@@ -259,13 +259,13 @@ class ApiToolManageService:
provider_name = provider_name.strip()
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
provider = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider,
)
.first()
.limit(1)
)
if provider is None:
@@ -328,13 +328,13 @@ class ApiToolManageService:
"""
delete tool provider
"""
provider = (
db.session.query(ApiToolProvider)
provider = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
.limit(1)
)
if provider is None:
@@ -378,13 +378,13 @@ class ApiToolManageService:
if tool_bundle is None:
raise ValueError(f"invalid tool name {tool_name}")
db_provider = (
db.session.query(ApiToolProvider)
db_provider = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
.limit(1)
)
if not db_provider:

View File

@@ -158,7 +158,7 @@ def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_fo
credential_id="cred-1",
enabled=True,
)
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config]
mock_db.session.scalars.return_value.all.return_value = [config]
mocker.patch(
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
return_value=("rsa", "cipher"),
@@ -216,7 +216,7 @@ def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate
credential_id=None,
enabled=False,
)
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
mock_db.session.scalars.return_value.all.return_value = [
normal_config,
inherit_config,
]
@@ -269,7 +269,7 @@ def test_get_load_balancing_config_should_return_none_when_config_not_found(
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
# Act
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
@@ -289,7 +289,7 @@ def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_
}
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True)
mock_db.session.query.return_value.where.return_value.first.return_value = config
mock_db.session.scalar.return_value = config
# Act
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
@@ -389,7 +389,7 @@ def test_update_load_balancing_configs_should_raise_value_error_when_credential_
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.scalars.return_value.all.return_value = []
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"):
@@ -578,7 +578,7 @@ def test_update_load_balancing_configs_should_create_from_existing_provider_cred
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.scalars.return_value.all.return_value = []
credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}')
mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record
mock_db.session.scalar.return_value = credential_record
# Act
service.update_load_balancing_configs(
@@ -623,7 +623,7 @@ def test_validate_load_balancing_credentials_should_raise_value_error_when_confi
# Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"):
@@ -646,7 +646,7 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
existing_config = SimpleNamespace(id="cfg-1")
mock_db.session.query.return_value.where.return_value.first.return_value = existing_config
mock_db.session.scalar.return_value = existing_config
mock_validate = mocker.patch.object(service, "_custom_credentials_validate")
# Act