mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 05:09:19 +08:00
refactor: model_load_balancing_service and api_tools_manage_service (#34434)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -110,20 +110,21 @@ class ModelLoadBalancingService:
|
||||
credential_source_type = CredentialSourceType.CUSTOM_MODEL
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
or_(
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||
LoadBalancingModelConfig.credential_source_type.is_(None),
|
||||
),
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
.all()
|
||||
load_balancing_configs = list(
|
||||
db.session.scalars(
|
||||
select(LoadBalancingModelConfig)
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
or_(
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||
LoadBalancingModelConfig.credential_source_type.is_(None),
|
||||
),
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
).all()
|
||||
)
|
||||
|
||||
if provider_configuration.custom_configuration.provider:
|
||||
@@ -143,7 +144,7 @@ class ModelLoadBalancingService:
|
||||
load_balancing_configs.insert(0, inherit_config)
|
||||
else:
|
||||
# move the inherit configuration to the first
|
||||
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
|
||||
for i, load_balancing_config in enumerate(load_balancing_configs.copy()):
|
||||
if load_balancing_config.name == "__inherit__":
|
||||
inherit_config = load_balancing_configs.pop(i)
|
||||
load_balancing_configs.insert(0, inherit_config)
|
||||
@@ -235,8 +236,8 @@ class ModelLoadBalancingService:
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_model_config = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
load_balancing_model_config = db.session.scalar(
|
||||
select(LoadBalancingModelConfig)
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
@@ -244,7 +245,7 @@ class ModelLoadBalancingService:
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not load_balancing_model_config:
|
||||
@@ -351,26 +352,26 @@ class ModelLoadBalancingService:
|
||||
|
||||
if credential_id:
|
||||
if config_from == "predefined-model":
|
||||
credential_record = (
|
||||
db.session.query(ProviderCredential)
|
||||
.filter_by(
|
||||
id=credential_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
credential_record = db.session.scalar(
|
||||
select(ProviderCredential)
|
||||
.where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == tenant_id,
|
||||
ProviderCredential.provider_name == provider_configuration.provider.provider,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
credential_record = (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.filter_by(
|
||||
id=credential_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type_enum,
|
||||
credential_record = db.session.scalar(
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name == provider_configuration.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type_enum,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not credential_record:
|
||||
raise ValueError(f"Provider credential with id {credential_id} not found")
|
||||
@@ -510,8 +511,8 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = None
|
||||
if config_id:
|
||||
# Get load balancing config
|
||||
load_balancing_model_config = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
load_balancing_model_config = db.session.scalar(
|
||||
select(LoadBalancingModelConfig)
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
@@ -519,7 +520,7 @@ class ModelLoadBalancingService:
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not load_balancing_model_config:
|
||||
|
||||
@@ -124,13 +124,13 @@ class ApiToolManageService:
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is not None:
|
||||
@@ -215,13 +215,13 @@ class ApiToolManageService:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@@ -259,13 +259,13 @@ class ApiToolManageService:
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@@ -328,13 +328,13 @@ class ApiToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@@ -378,13 +378,13 @@ class ApiToolManageService:
|
||||
if tool_bundle is None:
|
||||
raise ValueError(f"invalid tool name {tool_name}")
|
||||
|
||||
db_provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
db_provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
|
||||
@@ -158,7 +158,7 @@ def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_fo
|
||||
credential_id="cred-1",
|
||||
enabled=True,
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config]
|
||||
mock_db.session.scalars.return_value.all.return_value = [config]
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
|
||||
return_value=("rsa", "cipher"),
|
||||
@@ -216,7 +216,7 @@ def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate
|
||||
credential_id=None,
|
||||
enabled=False,
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
|
||||
mock_db.session.scalars.return_value.all.return_value = [
|
||||
normal_config,
|
||||
inherit_config,
|
||||
]
|
||||
@@ -269,7 +269,7 @@ def test_get_load_balancing_config_should_return_none_when_config_not_found(
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||
@@ -289,7 +289,7 @@ def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_
|
||||
}
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = config
|
||||
mock_db.session.scalar.return_value = config
|
||||
|
||||
# Act
|
||||
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||
@@ -389,7 +389,7 @@ def test_update_load_balancing_configs_should_raise_value_error_when_credential_
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"):
|
||||
@@ -578,7 +578,7 @@ def test_update_load_balancing_configs_should_create_from_existing_provider_cred
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}')
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record
|
||||
mock_db.session.scalar.return_value = credential_record
|
||||
|
||||
# Act
|
||||
service.update_load_balancing_configs(
|
||||
@@ -623,7 +623,7 @@ def test_validate_load_balancing_credentials_should_raise_value_error_when_confi
|
||||
# Arrange
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"):
|
||||
@@ -646,7 +646,7 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_
|
||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||
existing_config = SimpleNamespace(id="cfg-1")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = existing_config
|
||||
mock_db.session.scalar.return_value = existing_config
|
||||
mock_validate = mocker.patch.object(service, "_custom_credentials_validate")
|
||||
|
||||
# Act
|
||||
|
||||
Reference in New Issue
Block a user