diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 06f83a18f7a..faa978afdcd 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from typing import Any from graphon.model_runtime.entities.provider_entities import FormType +from sqlalchemy import func, select from sqlalchemy.orm import Session from configs import dify_config @@ -367,16 +368,16 @@ class DatasourceProviderService: check if tenant oauth params is enabled """ return ( - db.session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - enabled=True, + db.session.scalar( + select(func.count(DatasourceOauthTenantParamConfig.id)).where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, + DatasourceOauthTenantParamConfig.enabled == True, + ) ) - .count() - > 0 - ) + or 0 + ) > 0 def get_tenant_oauth_client( self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False @@ -384,14 +385,14 @@ class DatasourceProviderService: """ get tenant oauth client """ - tenant_oauth_client_params = ( - db.session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + tenant_oauth_client_params = db.session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if tenant_oauth_client_params: encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) @@ -707,24 +708,27 @@ class DatasourceProviderService: :return: """ # Get all provider configurations of the current workspace - datasource_providers: list[DatasourceProvider] = ( - db.session.query(DatasourceProvider) + datasource_providers: list[DatasourceProvider] = list( + db.session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + ).all() + ) + if not datasource_providers: + return [] + copy_credentials_list = [] + default_provider = db.session.execute( + select(DatasourceProvider.id) .where( DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.provider == provider, DatasourceProvider.plugin_id == plugin_id, ) - .all() - ) - if not datasource_providers: - return [] - copy_credentials_list = [] - default_provider = ( - db.session.query(DatasourceProvider.id) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .first() - ) + ).first() default_provider_id = default_provider.id if default_provider else None for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials @@ -880,14 +884,14 @@ class DatasourceProviderService: :return: """ # Get all provider configurations of the current workspace - datasource_providers: list[DatasourceProvider] = ( - db.session.query(DatasourceProvider) - .where( - DatasourceProvider.tenant_id == tenant_id, - DatasourceProvider.provider == provider, - DatasourceProvider.plugin_id == plugin_id, - ) - .all() + datasource_providers: list[DatasourceProvider] = list( + db.session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + ).all() ) if not datasource_providers: return [] @@ -987,10 +991,15 @@ class DatasourceProviderService: :param plugin_id: plugin id :return: """ - datasource_provider = ( - db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) - .first() + datasource_provider = db.session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == auth_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .limit(1) ) if datasource_provider: db.session.delete(datasource_provider) diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index da414816ff1..bc4120e2af4 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -57,6 +57,10 @@ class TestDatasourceProviderService: q.count.return_value = 0 q.delete.return_value = 1 + # Default values for select()-style calls (tests override per-case) + sess.scalar.return_value = None + sess.scalars.return_value.all.return_value = [] + mock_cls.return_value.__enter__.return_value = sess mock_cls.return_value.no_autoflush.__enter__.return_value = sess @@ -183,11 +187,11 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session): - mock_db_session.query().count.return_value = 1 + mock_db_session.scalar.return_value = 1 assert service.is_tenant_oauth_params_enabled("t1", make_id()) is True def test_should_return_false_when_tenant_oauth_params_disabled(self, service, mock_db_session): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 assert service.is_tenant_oauth_params_enabled("t1", make_id()) is False # ----------------------------------------------------------------------- @@ -401,7 +405,7 @@ class TestDatasourceProviderService: def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session): tenant_params = MagicMock() tenant_params.client_params = {"k": "v"} - mock_db_session.query().first.return_value = tenant_params + mock_db_session.scalar.return_value = tenant_params with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): result = service.get_tenant_oauth_client("t1", make_id(), mask=True) assert result == {"k": "mask"} @@ -409,13 +413,13 @@ class TestDatasourceProviderService: def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session): tenant_params = MagicMock() tenant_params.client_params = {"k": "v"} - mock_db_session.query().first.return_value = tenant_params + mock_db_session.scalar.return_value = tenant_params with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): result = service.get_tenant_oauth_client("t1", make_id(), mask=False) assert result == {"k": "dec"} def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None assert service.get_tenant_oauth_client("t1", make_id()) is None # ----------------------------------------------------------------------- @@ -616,7 +620,7 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session): - mock_db_session.query().all.return_value = [] + mock_db_session.scalars.return_value.all.return_value = [] assert service.list_datasource_credentials("t1", "prov", "org/plug") == [] def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session): @@ -624,7 +628,7 @@ class TestDatasourceProviderService: p.auth_type = "api_key" p.encrypted_credentials = {"sk": "v"} p.is_default = False - mock_db_session.query().all.return_value = [p] + mock_db_session.scalars.return_value.all.return_value = [p] with patch.object(service, "extract_secret_variables", return_value=["sk"]): result = service.list_datasource_credentials("t1", "prov", "org/plug") assert len(result) == 1 @@ -676,14 +680,14 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session): - mock_db_session.query().all.return_value = [] + mock_db_session.scalars.return_value.all.return_value = [] assert service.get_real_datasource_credentials("t1", "prov", "org/plug") == [] def test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) p.auth_type = "api_key" p.encrypted_credentials = {"sk": "v"} - mock_db_session.query().all.return_value = [p] + mock_db_session.scalars.return_value.all.return_value = [p] with patch.object(service, "extract_secret_variables", return_value=["sk"]): result = service.get_real_datasource_credentials("t1", "prov", "org/plug") assert len(result) == 1 @@ -751,13 +755,13 @@ class TestDatasourceProviderService: def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) - mock_db_session.query().first.return_value = p + mock_db_session.scalar.return_value = p service.remove_datasource_credentials("t1", "id", "prov", "org/plug") mock_db_session.delete.assert_called_once_with(p) mock_db_session.commit.assert_called_once() def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session): """No error raised; no delete called when record doesn't exist (lines 994 branch).""" - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None service.remove_datasource_credentials("t1", "id", "prov", "org/plug") mock_db_session.delete.assert_not_called()