From 399d3f8da57f31af4fb20caa4d38c92c7c3cc172 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Thu, 2 Apr 2026 06:38:35 +0200 Subject: [PATCH] refactor: model_load_balancing_service and api_tools_manage_service (#34434) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/model_load_balancing_service.py | 75 ++++++++++--------- .../tools/api_tools_manage_service.py | 30 ++++---- .../test_model_load_balancing_service.py | 16 ++-- 3 files changed, 61 insertions(+), 60 deletions(-) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 752d3002d9a..bc0bfd215cf 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -110,20 +110,21 @@ class ModelLoadBalancingService: credential_source_type = CredentialSourceType.CUSTOM_MODEL # Get load balancing configurations - load_balancing_configs = ( - db.session.query(LoadBalancingModelConfig) - .where( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type_enum, - LoadBalancingModelConfig.model_name == model, - or_( - LoadBalancingModelConfig.credential_source_type == credential_source_type, - LoadBalancingModelConfig.credential_source_type.is_(None), - ), - ) - .order_by(LoadBalancingModelConfig.created_at) - .all() + load_balancing_configs = list( + db.session.scalars( + select(LoadBalancingModelConfig) + .where( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type_enum, + LoadBalancingModelConfig.model_name == model, + or_( + LoadBalancingModelConfig.credential_source_type == credential_source_type, + LoadBalancingModelConfig.credential_source_type.is_(None), + ), + ) + .order_by(LoadBalancingModelConfig.created_at) + ).all() ) if provider_configuration.custom_configuration.provider: @@ -143,7 +144,7 @@ class ModelLoadBalancingService: load_balancing_configs.insert(0, inherit_config) else: # move the inherit configuration to the first - for i, load_balancing_config in enumerate(load_balancing_configs[:]): + for i, load_balancing_config in enumerate(load_balancing_configs.copy()): if load_balancing_config.name == "__inherit__": inherit_config = load_balancing_configs.pop(i) load_balancing_configs.insert(0, inherit_config) @@ -235,8 +236,8 @@ class ModelLoadBalancingService: model_type_enum = ModelType.value_of(model_type) # Get load balancing configurations - load_balancing_model_config = ( - db.session.query(LoadBalancingModelConfig) + load_balancing_model_config = db.session.scalar( + select(LoadBalancingModelConfig) .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, @@ -244,7 +245,7 @@ class ModelLoadBalancingService: LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) - .first() + .limit(1) ) if not load_balancing_model_config: @@ -351,26 +352,26 @@ class ModelLoadBalancingService: if credential_id: if config_from == "predefined-model": - credential_record = ( - db.session.query(ProviderCredential) - .filter_by( - id=credential_id, - tenant_id=tenant_id, - provider_name=provider_configuration.provider.provider, + credential_record = db.session.scalar( + select(ProviderCredential) + .where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == tenant_id, + ProviderCredential.provider_name == provider_configuration.provider.provider, ) - .first() + .limit(1) ) else: - credential_record = ( - db.session.query(ProviderModelCredential) - .filter_by( - id=credential_id, - tenant_id=tenant_id, - provider_name=provider_configuration.provider.provider, - model_name=model, - model_type=model_type_enum, + credential_record = db.session.scalar( + select(ProviderModelCredential) + .where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == tenant_id, + ProviderModelCredential.provider_name == provider_configuration.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type_enum, ) - .first() + .limit(1) ) if not credential_record: raise ValueError(f"Provider credential with id {credential_id} not found") @@ -510,8 +511,8 @@ class ModelLoadBalancingService: load_balancing_model_config = None if config_id: # Get load balancing config - load_balancing_model_config = ( - db.session.query(LoadBalancingModelConfig) + load_balancing_model_config = db.session.scalar( + select(LoadBalancingModelConfig) .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider, @@ -519,7 +520,7 @@ class ModelLoadBalancingService: LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) - .first() + .limit(1) ) if not load_balancing_model_config: diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 2a56bc0c71e..0a6968700f8 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -124,13 +124,13 @@ class ApiToolManageService: provider_name = provider_name.strip() # check if the provider exists - provider = ( - db.session.query(ApiToolProvider) + provider = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) - .first() + .limit(1) ) if provider is not None: @@ -215,13 +215,13 @@ class ApiToolManageService: """ list api tool provider tools """ - provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) - .first() + .limit(1) ) if provider is None: @@ -259,13 +259,13 @@ class ApiToolManageService: provider_name = provider_name.strip() # check if the provider exists - provider = ( - db.session.query(ApiToolProvider) + provider = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == original_provider, ) - .first() + .limit(1) ) if provider is None: @@ -328,13 +328,13 @@ class ApiToolManageService: """ delete tool provider """ - provider = ( - db.session.query(ApiToolProvider) + provider = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) - .first() + .limit(1) ) if provider is None: @@ -378,13 +378,13 @@ class ApiToolManageService: if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider = ( - db.session.query(ApiToolProvider) + db_provider = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) - .first() + .limit(1) ) if not db_provider: diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py index f85f1ace16d..bea288fb9b2 100644 --- a/api/tests/unit_tests/services/test_model_load_balancing_service.py +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -158,7 +158,7 @@ def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_fo credential_id="cred-1", enabled=True, ) - mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config] + mock_db.session.scalars.return_value.all.return_value = [config] mocker.patch( "services.model_load_balancing_service.encrypter.get_decrypt_decoding", return_value=("rsa", "cipher"), @@ -216,7 +216,7 @@ def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate credential_id=None, enabled=False, ) - mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ + mock_db.session.scalars.return_value.all.return_value = [ normal_config, inherit_config, ] @@ -269,7 +269,7 @@ def test_get_load_balancing_config_should_return_none_when_config_not_found( # Arrange provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") @@ -289,7 +289,7 @@ def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_ } service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True) - mock_db.session.query.return_value.where.return_value.first.return_value = config + mock_db.session.scalar.return_value = config # Act result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") @@ -389,7 +389,7 @@ def test_update_load_balancing_configs_should_raise_value_error_when_credential_ provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} mock_db.session.scalars.return_value.all.return_value = [] - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"): @@ -578,7 +578,7 @@ def test_update_load_balancing_configs_should_create_from_existing_provider_cred service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} mock_db.session.scalars.return_value.all.return_value = [] credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}') - mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record + mock_db.session.scalar.return_value = credential_record # Act service.update_load_balancing_configs( @@ -623,7 +623,7 @@ def test_validate_load_balancing_credentials_should_raise_value_error_when_confi # Arrange provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"): @@ -646,7 +646,7 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_ provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} existing_config = SimpleNamespace(id="cfg-1") - mock_db.session.query.return_value.where.return_value.first.return_value = existing_config + mock_db.session.scalar.return_value = existing_config mock_validate = mocker.patch.object(service, "_custom_credentials_validate") # Act