mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 16:03:14 +08:00
refactor: select in datasource_provider_service (#34548)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.provider_entities import FormType
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@@ -367,16 +368,16 @@ class DatasourceProviderService:
|
||||
check if tenant oauth params is enabled
|
||||
"""
|
||||
return (
|
||||
db.session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
enabled=True,
|
||||
db.session.scalar(
|
||||
select(func.count(DatasourceOauthTenantParamConfig.id)).where(
|
||||
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
|
||||
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
|
||||
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
|
||||
DatasourceOauthTenantParamConfig.enabled == True,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
> 0
|
||||
)
|
||||
or 0
|
||||
) > 0
|
||||
|
||||
def get_tenant_oauth_client(
|
||||
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
|
||||
@@ -384,14 +385,14 @@ class DatasourceProviderService:
|
||||
"""
|
||||
get tenant oauth client
|
||||
"""
|
||||
tenant_oauth_client_params = (
|
||||
db.session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
tenant_oauth_client_params = db.session.scalar(
|
||||
select(DatasourceOauthTenantParamConfig)
|
||||
.where(
|
||||
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
|
||||
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
|
||||
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if tenant_oauth_client_params:
|
||||
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
||||
@@ -707,24 +708,27 @@ class DatasourceProviderService:
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
datasource_providers: list[DatasourceProvider] = (
|
||||
db.session.query(DatasourceProvider)
|
||||
datasource_providers: list[DatasourceProvider] = list(
|
||||
db.session.scalars(
|
||||
select(DatasourceProvider).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
).all()
|
||||
)
|
||||
if not datasource_providers:
|
||||
return []
|
||||
copy_credentials_list = []
|
||||
default_provider = db.session.execute(
|
||||
select(DatasourceProvider.id)
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if not datasource_providers:
|
||||
return []
|
||||
copy_credentials_list = []
|
||||
default_provider = (
|
||||
db.session.query(DatasourceProvider.id)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
).first()
|
||||
default_provider_id = default_provider.id if default_provider else None
|
||||
for datasource_provider in datasource_providers:
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
@@ -880,14 +884,14 @@ class DatasourceProviderService:
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
datasource_providers: list[DatasourceProvider] = (
|
||||
db.session.query(DatasourceProvider)
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
.all()
|
||||
datasource_providers: list[DatasourceProvider] = list(
|
||||
db.session.scalars(
|
||||
select(DatasourceProvider).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
).all()
|
||||
)
|
||||
if not datasource_providers:
|
||||
return []
|
||||
@@ -987,10 +991,15 @@ class DatasourceProviderService:
|
||||
:param plugin_id: plugin id
|
||||
:return:
|
||||
"""
|
||||
datasource_provider = (
|
||||
db.session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
|
||||
.first()
|
||||
datasource_provider = db.session.scalar(
|
||||
select(DatasourceProvider)
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.id == auth_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if datasource_provider:
|
||||
db.session.delete(datasource_provider)
|
||||
|
||||
@@ -57,6 +57,10 @@ class TestDatasourceProviderService:
|
||||
q.count.return_value = 0
|
||||
q.delete.return_value = 1
|
||||
|
||||
# Default values for select()-style calls (tests override per-case)
|
||||
sess.scalar.return_value = None
|
||||
sess.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_cls.return_value.__enter__.return_value = sess
|
||||
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
|
||||
|
||||
@@ -183,11 +187,11 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 1
|
||||
mock_db_session.scalar.return_value = 1
|
||||
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is True
|
||||
|
||||
def test_should_return_false_when_tenant_oauth_params_disabled(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.return_value = 0
|
||||
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is False
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -401,7 +405,7 @@ class TestDatasourceProviderService:
|
||||
def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session):
|
||||
tenant_params = MagicMock()
|
||||
tenant_params.client_params = {"k": "v"}
|
||||
mock_db_session.query().first.return_value = tenant_params
|
||||
mock_db_session.scalar.return_value = tenant_params
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
result = service.get_tenant_oauth_client("t1", make_id(), mask=True)
|
||||
assert result == {"k": "mask"}
|
||||
@@ -409,13 +413,13 @@ class TestDatasourceProviderService:
|
||||
def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session):
|
||||
tenant_params = MagicMock()
|
||||
tenant_params.client_params = {"k": "v"}
|
||||
mock_db_session.query().first.return_value = tenant_params
|
||||
mock_db_session.scalar.return_value = tenant_params
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
result = service.get_tenant_oauth_client("t1", make_id(), mask=False)
|
||||
assert result == {"k": "dec"}
|
||||
|
||||
def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
assert service.get_tenant_oauth_client("t1", make_id()) is None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -616,7 +620,7 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session):
|
||||
mock_db_session.query().all.return_value = []
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
assert service.list_datasource_credentials("t1", "prov", "org/plug") == []
|
||||
|
||||
def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session):
|
||||
@@ -624,7 +628,7 @@ class TestDatasourceProviderService:
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "v"}
|
||||
p.is_default = False
|
||||
mock_db_session.query().all.return_value = [p]
|
||||
mock_db_session.scalars.return_value.all.return_value = [p]
|
||||
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
||||
result = service.list_datasource_credentials("t1", "prov", "org/plug")
|
||||
assert len(result) == 1
|
||||
@@ -676,14 +680,14 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session):
|
||||
mock_db_session.query().all.return_value = []
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
assert service.get_real_datasource_credentials("t1", "prov", "org/plug") == []
|
||||
|
||||
def test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "v"}
|
||||
mock_db_session.query().all.return_value = [p]
|
||||
mock_db_session.scalars.return_value.all.return_value = [p]
|
||||
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
||||
result = service.get_real_datasource_credentials("t1", "prov", "org/plug")
|
||||
assert len(result) == 1
|
||||
@@ -751,13 +755,13 @@ class TestDatasourceProviderService:
|
||||
|
||||
def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.scalar.return_value = p
|
||||
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
||||
mock_db_session.delete.assert_called_once_with(p)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
|
||||
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""
|
||||
mock_db_session.query().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
||||
mock_db_session.delete.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user