refactor: select in datasource_provider_service (#34548)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-04-04 19:12:15 -05:00
committed by GitHub
parent c2428361c4
commit 779e6b8e0b
2 changed files with 64 additions and 51 deletions

View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any
from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -367,16 +368,16 @@ class DatasourceProviderService:
check if tenant oauth params is enabled
"""
return (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
enabled=True,
db.session.scalar(
select(func.count(DatasourceOauthTenantParamConfig.id)).where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
DatasourceOauthTenantParamConfig.enabled == True,
)
)
.count()
> 0
)
or 0
) > 0
def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
@@ -384,14 +385,14 @@ class DatasourceProviderService:
"""
get tenant oauth client
"""
tenant_oauth_client_params = (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
tenant_oauth_client_params = db.session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
@@ -707,24 +708,27 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
datasource_providers: list[DatasourceProvider] = list(
db.session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
).all()
)
if not datasource_providers:
return []
copy_credentials_list = []
default_provider = db.session.execute(
select(DatasourceProvider.id)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
default_provider = (
db.session.query(DatasourceProvider.id)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
).first()
default_provider_id = default_provider.id if default_provider else None
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
@@ -880,14 +884,14 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
datasource_providers: list[DatasourceProvider] = list(
db.session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
).all()
)
if not datasource_providers:
return []
@@ -987,10 +991,15 @@ class DatasourceProviderService:
:param plugin_id: plugin id
:return:
"""
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
datasource_provider = db.session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == auth_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.limit(1)
)
if datasource_provider:
db.session.delete(datasource_provider)

View File

@@ -57,6 +57,10 @@ class TestDatasourceProviderService:
q.count.return_value = 0
q.delete.return_value = 1
# Default values for select()-style calls (tests override per-case)
sess.scalar.return_value = None
sess.scalars.return_value.all.return_value = []
mock_cls.return_value.__enter__.return_value = sess
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
@@ -183,11 +187,11 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session):
mock_db_session.query().count.return_value = 1
mock_db_session.scalar.return_value = 1
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is True
def test_should_return_false_when_tenant_oauth_params_disabled(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is False
# -----------------------------------------------------------------------
@@ -401,7 +405,7 @@ class TestDatasourceProviderService:
def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session):
tenant_params = MagicMock()
tenant_params.client_params = {"k": "v"}
mock_db_session.query().first.return_value = tenant_params
mock_db_session.scalar.return_value = tenant_params
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
result = service.get_tenant_oauth_client("t1", make_id(), mask=True)
assert result == {"k": "mask"}
@@ -409,13 +413,13 @@ class TestDatasourceProviderService:
def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session):
tenant_params = MagicMock()
tenant_params.client_params = {"k": "v"}
mock_db_session.query().first.return_value = tenant_params
mock_db_session.scalar.return_value = tenant_params
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
result = service.get_tenant_oauth_client("t1", make_id(), mask=False)
assert result == {"k": "dec"}
def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
assert service.get_tenant_oauth_client("t1", make_id()) is None
# -----------------------------------------------------------------------
@@ -616,7 +620,7 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session):
mock_db_session.query().all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
assert service.list_datasource_credentials("t1", "prov", "org/plug") == []
def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session):
@@ -624,7 +628,7 @@ class TestDatasourceProviderService:
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "v"}
p.is_default = False
mock_db_session.query().all.return_value = [p]
mock_db_session.scalars.return_value.all.return_value = [p]
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
result = service.list_datasource_credentials("t1", "prov", "org/plug")
assert len(result) == 1
@@ -676,14 +680,14 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session):
mock_db_session.query().all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
assert service.get_real_datasource_credentials("t1", "prov", "org/plug") == []
def test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "v"}
mock_db_session.query().all.return_value = [p]
mock_db_session.scalars.return_value.all.return_value = [p]
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
result = service.get_real_datasource_credentials("t1", "prov", "org/plug")
assert len(result) == 1
@@ -751,13 +755,13 @@ class TestDatasourceProviderService:
def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
mock_db_session.delete.assert_called_once_with(p)
mock_db_session.commit.assert_called_once()
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
mock_db_session.delete.assert_not_called()