refactor: select in dataset_service (DatasetService class) (#34525)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-04-03 14:01:31 +02:00
committed by GitHub
parent 06dde4f503
commit e85d9a0d72
2 changed files with 56 additions and 55 deletions

View File

@@ -14,7 +14,7 @@ from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy import exists, func, select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
@@ -114,9 +114,11 @@ class DatasetService:
if user:
# get permitted dataset ids
dataset_permission = (
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
)
dataset_permission = db.session.scalars(
select(DatasetPermission).where(
DatasetPermission.account_id == user.id, DatasetPermission.tenant_id == tenant_id
)
).all()
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@@ -182,13 +184,12 @@ class DatasetService:
@staticmethod
def get_process_rules(dataset_id):
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
dataset_process_rule = db.session.execute(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
)
).scalar_one_or_none()
if dataset_process_rule:
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
@@ -225,7 +226,7 @@ class DatasetService:
summary_index_setting: dict | None = None,
):
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@@ -300,17 +301,17 @@ class DatasetService:
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
if db.session.scalar(
select(Dataset)
.where(Dataset.name == rag_pipeline_dataset_create_entity.name, Dataset.tenant_id == tenant_id)
.limit(1)
):
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
else:
# generate a random name as Untitled 1 2 3 ...
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
datasets = db.session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
@@ -344,7 +345,7 @@ class DatasetService:
@staticmethod
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
dataset: Dataset | None = db.session.get(Dataset, dataset_id)
return dataset
@staticmethod
@@ -466,14 +467,14 @@ class DatasetService:
@staticmethod
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
dataset = (
db.session.query(Dataset)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id != dataset_id,
Dataset.name == name,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
return dataset is not None
@@ -596,7 +597,7 @@ class DatasetService:
filtered_data["icon_info"] = data.get("icon_info")
# Update dataset in database
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data))
db.session.commit()
# Reload dataset to get updated values
@@ -631,7 +632,7 @@ class DatasetService:
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
return
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
pipeline = db.session.get(Pipeline, dataset.pipeline_id)
if not pipeline:
return
@@ -1138,8 +1139,10 @@ class DatasetService:
if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
# For partial team permission, user needs explicit permission or be the creator
if dataset.created_by != user.id:
user_permission = (
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
user_permission = db.session.scalar(
select(DatasetPermission)
.where(DatasetPermission.dataset_id == dataset.id, DatasetPermission.account_id == user.id)
.limit(1)
)
if not user_permission:
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
@@ -1161,7 +1164,9 @@ class DatasetService:
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
for dp in db.session.scalars(
select(DatasetPermission).where(DatasetPermission.account_id == user.id)
).all()
):
raise NoPermissionError("You do not have permission to access this dataset.")
@@ -1175,12 +1180,11 @@ class DatasetService:
@staticmethod
def get_related_apps(dataset_id: str):
return (
db.session.query(AppDatasetJoin)
return db.session.scalars(
select(AppDatasetJoin)
.where(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
.order_by(AppDatasetJoin.created_at.desc())
).all()
@staticmethod
def update_dataset_api_status(dataset_id: str, status: bool):

View File

@@ -62,7 +62,7 @@ class TestDatasetServiceQueries:
self, mock_dataset_query_dependencies
):
user = DatasetServiceUnitDataFactory.create_user_mock(role=TenantAccountRole.DATASET_OPERATOR)
mock_dataset_query_dependencies["db"].session.query.return_value.filter_by.return_value.all.return_value = []
mock_dataset_query_dependencies["db"].session.scalars.return_value.all.return_value = []
items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1", user=user)
@@ -108,9 +108,7 @@ class TestDatasetServiceQueries:
dataset_process_rule.rules_dict = {"delimiter": "\n"}
with patch("services.dataset_service.db") as mock_db:
(
mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value
) = dataset_process_rule
(mock_db.session.execute.return_value.scalar_one_or_none.return_value) = dataset_process_rule
result = DatasetService.get_process_rules("dataset-1")
@@ -118,9 +116,7 @@ class TestDatasetServiceQueries:
def test_get_process_rules_falls_back_to_default_rules_when_missing(self):
with patch("services.dataset_service.db") as mock_db:
(
mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value
) = None
(mock_db.session.execute.return_value.scalar_one_or_none.return_value) = None
result = DatasetService.get_process_rules("dataset-1")
@@ -151,7 +147,7 @@ class TestDatasetServiceQueries:
dataset = DatasetServiceUnitDataFactory.create_dataset_mock()
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset
mock_db.session.get.return_value = dataset
result = DatasetService.get_dataset(dataset.id)
@@ -308,7 +304,7 @@ class TestDatasetServiceCreationAndUpdate:
account = SimpleNamespace(id="user-1")
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
mock_db.session.scalar.return_value = object()
with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"):
DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", account)
@@ -319,6 +315,7 @@ class TestDatasetServiceCreationAndUpdate:
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.select"),
patch(
"services.dataset_service.Dataset",
side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs),
@@ -326,7 +323,7 @@ class TestDatasetServiceCreationAndUpdate:
patch("services.dataset_service.ModelManager") as model_manager_cls,
patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model
dataset = DatasetService.create_empty_dataset(
@@ -355,6 +352,7 @@ class TestDatasetServiceCreationAndUpdate:
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.select"),
patch(
"services.dataset_service.Dataset",
side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs),
@@ -368,7 +366,7 @@ class TestDatasetServiceCreationAndUpdate:
patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
patch.object(DatasetService, "check_reranking_model_setting") as check_reranking,
):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
dataset = DatasetService.create_empty_dataset(
@@ -412,7 +410,7 @@ class TestDatasetServiceCreationAndUpdate:
)
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
mock_db.session.scalar.return_value = object()
with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"):
DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)
@@ -435,12 +433,13 @@ class TestDatasetServiceCreationAndUpdate:
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.select"),
patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")),
patch("services.dataset_service.generate_incremental_name", return_value="Untitled 2") as generate_name,
patch("services.dataset_service.Pipeline", side_effect=pipeline_factory),
patch("services.dataset_service.Dataset", side_effect=dataset_factory),
):
mock_db.session.query.return_value.filter_by.return_value.all.return_value = [
mock_db.session.scalars.return_value.all.return_value = [
SimpleNamespace(name="Untitled"),
SimpleNamespace(name="Untitled 1"),
]
@@ -465,7 +464,7 @@ class TestDatasetServiceCreationAndUpdate:
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.current_user", SimpleNamespace(id=None)),
):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
with pytest.raises(ValueError, match="Current user or current user id not found"):
DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)
@@ -520,7 +519,7 @@ class TestDatasetServiceCreationAndUpdate:
def test_has_dataset_same_name_returns_true_when_query_matches(self):
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = object()
mock_db.session.scalar.return_value = object()
result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset")
@@ -630,7 +629,7 @@ class TestDatasetServiceCreationAndUpdate:
result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user)
assert result is dataset
updated_values = mock_db.session.query.return_value.filter_by.return_value.update.call_args.args[0]
updated_values = mock_db.session.execute.call_args.args[0].compile().params
assert updated_values["name"] == "Updated Dataset"
assert updated_values["description"] is None
assert updated_values["retrieval_model"] == {"top_k": 4}
@@ -658,13 +657,13 @@ class TestDatasetServiceCreationAndUpdate:
with patch("services.dataset_service.db") as mock_db:
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
mock_db.session.query.assert_not_called()
mock_db.session.get.assert_not_called()
def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missing(self):
dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1")
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.get.return_value = None
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
@@ -703,7 +702,7 @@ class TestDatasetServiceCreationAndUpdate:
patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service),
patch("services.dataset_service.Workflow.new", return_value=new_workflow) as workflow_new,
):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline
mock_db.session.get.return_value = pipeline
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
@@ -725,7 +724,7 @@ class TestDatasetServiceCreationAndUpdate:
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service),
):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline
mock_db.session.get.return_value = pipeline
with pytest.raises(RuntimeError, match="boom"):
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
@@ -1364,7 +1363,7 @@ class TestDatasetServicePermissionsAndLifecycle:
)
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
with pytest.raises(NoPermissionError, match="do not have permission"):
DatasetService.check_dataset_permission(dataset, user)
@@ -1382,7 +1381,7 @@ class TestDatasetServicePermissionsAndLifecycle:
with patch("services.dataset_service.db") as mock_db:
DatasetService.check_dataset_permission(dataset, user)
mock_db.session.query.assert_not_called()
mock_db.session.scalar.assert_not_called()
def test_check_dataset_permission_allows_partial_team_member_with_binding(self):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
@@ -1395,7 +1394,7 @@ class TestDatasetServicePermissionsAndLifecycle:
)
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
mock_db.session.scalar.return_value = object()
DatasetService.check_dataset_permission(dataset, user)
@@ -1427,7 +1426,7 @@ class TestDatasetServicePermissionsAndLifecycle:
)
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.all.return_value = []
mock_db.session.scalars.return_value.all.return_value = []
with pytest.raises(NoPermissionError, match="do not have permission"):
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
@@ -1446,9 +1445,7 @@ class TestDatasetServicePermissionsAndLifecycle:
def test_get_related_apps_returns_ordered_query_results(self):
with patch("services.dataset_service.db") as mock_db:
mock_db.desc.side_effect = lambda column: column
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
"relation-1"
]
mock_db.session.scalars.return_value.all.return_value = ["relation-1"]
result = DatasetService.get_related_apps("dataset-1")