Mirror of https://github.com/langgenius/dify.git (synced 2026-04-05 05:09:19 +08:00)
refactor: select in dataset_service (DatasetService class) (#34525)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,7 @@ from graphon.file import helpers as file_helpers
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
from sqlalchemy import exists, func, select
|
||||
from sqlalchemy import exists, func, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@@ -114,9 +114,11 @@ class DatasetService:
|
||||
|
||||
if user:
|
||||
# get permitted dataset ids
|
||||
dataset_permission = (
|
||||
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
|
||||
)
|
||||
dataset_permission = db.session.scalars(
|
||||
select(DatasetPermission).where(
|
||||
DatasetPermission.account_id == user.id, DatasetPermission.tenant_id == tenant_id
|
||||
)
|
||||
).all()
|
||||
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
|
||||
|
||||
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
|
||||
@@ -182,13 +184,12 @@ class DatasetService:
|
||||
@staticmethod
|
||||
def get_process_rules(dataset_id):
|
||||
# get the latest process rule
|
||||
dataset_process_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
dataset_process_rule = db.session.execute(
|
||||
select(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == dataset_id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if dataset_process_rule:
|
||||
mode = dataset_process_rule.mode
|
||||
rules = dataset_process_rule.rules_dict
|
||||
@@ -225,7 +226,7 @@ class DatasetService:
|
||||
summary_index_setting: dict | None = None,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
|
||||
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
|
||||
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
|
||||
embedding_model = None
|
||||
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
@@ -300,17 +301,17 @@ class DatasetService:
|
||||
):
|
||||
if rag_pipeline_dataset_create_entity.name:
|
||||
# check if dataset name already exists
|
||||
if (
|
||||
db.session.query(Dataset)
|
||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||
.first()
|
||||
if db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(Dataset.name == rag_pipeline_dataset_create_entity.name, Dataset.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
):
|
||||
raise DatasetNameDuplicateError(
|
||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||
)
|
||||
else:
|
||||
# generate a random name as Untitled 1 2 3 ...
|
||||
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
|
||||
datasets = db.session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
|
||||
names = [dataset.name for dataset in datasets]
|
||||
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
|
||||
names,
|
||||
@@ -344,7 +345,7 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(dataset_id) -> Dataset | None:
|
||||
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset: Dataset | None = db.session.get(Dataset, dataset_id)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
@@ -466,14 +467,14 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id != dataset_id,
|
||||
Dataset.name == name,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
return dataset is not None
|
||||
|
||||
@@ -596,7 +597,7 @@ class DatasetService:
|
||||
filtered_data["icon_info"] = data.get("icon_info")
|
||||
|
||||
# Update dataset in database
|
||||
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
|
||||
db.session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data))
|
||||
db.session.commit()
|
||||
|
||||
# Reload dataset to get updated values
|
||||
@@ -631,7 +632,7 @@ class DatasetService:
|
||||
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
|
||||
return
|
||||
|
||||
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
|
||||
pipeline = db.session.get(Pipeline, dataset.pipeline_id)
|
||||
if not pipeline:
|
||||
return
|
||||
|
||||
@@ -1138,8 +1139,10 @@ class DatasetService:
|
||||
if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
# For partial team permission, user needs explicit permission or be the creator
|
||||
if dataset.created_by != user.id:
|
||||
user_permission = (
|
||||
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
|
||||
user_permission = db.session.scalar(
|
||||
select(DatasetPermission)
|
||||
.where(DatasetPermission.dataset_id == dataset.id, DatasetPermission.account_id == user.id)
|
||||
.limit(1)
|
||||
)
|
||||
if not user_permission:
|
||||
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
||||
@@ -1161,7 +1164,9 @@ class DatasetService:
|
||||
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
if not any(
|
||||
dp.dataset_id == dataset.id
|
||||
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
|
||||
for dp in db.session.scalars(
|
||||
select(DatasetPermission).where(DatasetPermission.account_id == user.id)
|
||||
).all()
|
||||
):
|
||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||
|
||||
@@ -1175,12 +1180,11 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_related_apps(dataset_id: str):
|
||||
return (
|
||||
db.session.query(AppDatasetJoin)
|
||||
return db.session.scalars(
|
||||
select(AppDatasetJoin)
|
||||
.where(AppDatasetJoin.dataset_id == dataset_id)
|
||||
.order_by(db.desc(AppDatasetJoin.created_at))
|
||||
.all()
|
||||
)
|
||||
.order_by(AppDatasetJoin.created_at.desc())
|
||||
).all()
|
||||
|
||||
@staticmethod
|
||||
def update_dataset_api_status(dataset_id: str, status: bool):
|
||||
|
||||
@@ -62,7 +62,7 @@ class TestDatasetServiceQueries:
|
||||
self, mock_dataset_query_dependencies
|
||||
):
|
||||
user = DatasetServiceUnitDataFactory.create_user_mock(role=TenantAccountRole.DATASET_OPERATOR)
|
||||
mock_dataset_query_dependencies["db"].session.query.return_value.filter_by.return_value.all.return_value = []
|
||||
mock_dataset_query_dependencies["db"].session.scalars.return_value.all.return_value = []
|
||||
|
||||
items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1", user=user)
|
||||
|
||||
@@ -108,9 +108,7 @@ class TestDatasetServiceQueries:
|
||||
dataset_process_rule.rules_dict = {"delimiter": "\n"}
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
(
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value
|
||||
) = dataset_process_rule
|
||||
(mock_db.session.execute.return_value.scalar_one_or_none.return_value) = dataset_process_rule
|
||||
|
||||
result = DatasetService.get_process_rules("dataset-1")
|
||||
|
||||
@@ -118,9 +116,7 @@ class TestDatasetServiceQueries:
|
||||
|
||||
def test_get_process_rules_falls_back_to_default_rules_when_missing(self):
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
(
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value
|
||||
) = None
|
||||
(mock_db.session.execute.return_value.scalar_one_or_none.return_value) = None
|
||||
|
||||
result = DatasetService.get_process_rules("dataset-1")
|
||||
|
||||
@@ -151,7 +147,7 @@ class TestDatasetServiceQueries:
|
||||
dataset = DatasetServiceUnitDataFactory.create_dataset_mock()
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset
|
||||
mock_db.session.get.return_value = dataset
|
||||
|
||||
result = DatasetService.get_dataset(dataset.id)
|
||||
|
||||
@@ -308,7 +304,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
account = SimpleNamespace(id="user-1")
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
mock_db.session.scalar.return_value = object()
|
||||
|
||||
with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"):
|
||||
DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", account)
|
||||
@@ -319,6 +315,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.select"),
|
||||
patch(
|
||||
"services.dataset_service.Dataset",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs),
|
||||
@@ -326,7 +323,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
||||
patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model
|
||||
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
@@ -355,6 +352,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.select"),
|
||||
patch(
|
||||
"services.dataset_service.Dataset",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs),
|
||||
@@ -368,7 +366,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
|
||||
patch.object(DatasetService, "check_reranking_model_setting") as check_reranking,
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
@@ -412,7 +410,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
mock_db.session.scalar.return_value = object()
|
||||
|
||||
with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"):
|
||||
DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)
|
||||
@@ -435,12 +433,13 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.select"),
|
||||
patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")),
|
||||
patch("services.dataset_service.generate_incremental_name", return_value="Untitled 2") as generate_name,
|
||||
patch("services.dataset_service.Pipeline", side_effect=pipeline_factory),
|
||||
patch("services.dataset_service.Dataset", side_effect=dataset_factory),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.all.return_value = [
|
||||
mock_db.session.scalars.return_value.all.return_value = [
|
||||
SimpleNamespace(name="Untitled"),
|
||||
SimpleNamespace(name="Untitled 1"),
|
||||
]
|
||||
@@ -465,7 +464,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.current_user", SimpleNamespace(id=None)),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Current user or current user id not found"):
|
||||
DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)
|
||||
@@ -520,7 +519,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
|
||||
def test_has_dataset_same_name_returns_true_when_query_matches(self):
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = object()
|
||||
mock_db.session.scalar.return_value = object()
|
||||
|
||||
result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset")
|
||||
|
||||
@@ -630,7 +629,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user)
|
||||
|
||||
assert result is dataset
|
||||
updated_values = mock_db.session.query.return_value.filter_by.return_value.update.call_args.args[0]
|
||||
updated_values = mock_db.session.execute.call_args.args[0].compile().params
|
||||
assert updated_values["name"] == "Updated Dataset"
|
||||
assert updated_values["description"] is None
|
||||
assert updated_values["retrieval_model"] == {"top_k": 4}
|
||||
@@ -658,13 +657,13 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
|
||||
|
||||
mock_db.session.query.assert_not_called()
|
||||
mock_db.session.get.assert_not_called()
|
||||
|
||||
def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missing(self):
|
||||
dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1")
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
|
||||
|
||||
@@ -703,7 +702,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service),
|
||||
patch("services.dataset_service.Workflow.new", return_value=new_workflow) as workflow_new,
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline
|
||||
mock_db.session.get.return_value = pipeline
|
||||
|
||||
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
|
||||
|
||||
@@ -725,7 +724,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline
|
||||
mock_db.session.get.return_value = pipeline
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
|
||||
@@ -1364,7 +1363,7 @@ class TestDatasetServicePermissionsAndLifecycle:
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(NoPermissionError, match="do not have permission"):
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
@@ -1382,7 +1381,7 @@ class TestDatasetServicePermissionsAndLifecycle:
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
mock_db.session.query.assert_not_called()
|
||||
mock_db.session.scalar.assert_not_called()
|
||||
|
||||
def test_check_dataset_permission_allows_partial_team_member_with_binding(self):
|
||||
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
|
||||
@@ -1395,7 +1394,7 @@ class TestDatasetServicePermissionsAndLifecycle:
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
mock_db.session.scalar.return_value = object()
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
@@ -1427,7 +1426,7 @@ class TestDatasetServicePermissionsAndLifecycle:
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.filter_by.return_value.all.return_value = []
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
with pytest.raises(NoPermissionError, match="do not have permission"):
|
||||
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
|
||||
@@ -1446,9 +1445,7 @@ class TestDatasetServicePermissionsAndLifecycle:
|
||||
def test_get_related_apps_returns_ordered_query_results(self):
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.desc.side_effect = lambda column: column
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
|
||||
"relation-1"
|
||||
]
|
||||
mock_db.session.scalars.return_value.all.return_value = ["relation-1"]
|
||||
|
||||
result = DatasetService.get_related_apps("dataset-1")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user