refactor: select in dataset_service (DatasetService class) (#34525)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-04-03 14:01:31 +02:00
committed by GitHub
parent 06dde4f503
commit e85d9a0d72
2 changed files with 56 additions and 55 deletions

View File

@@ -14,7 +14,7 @@ from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy import exists, func, select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
@@ -114,9 +114,11 @@ class DatasetService:
if user:
# get permitted dataset ids
dataset_permission = (
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
)
dataset_permission = db.session.scalars(
select(DatasetPermission).where(
DatasetPermission.account_id == user.id, DatasetPermission.tenant_id == tenant_id
)
).all()
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@@ -182,13 +184,12 @@ class DatasetService:
@staticmethod
def get_process_rules(dataset_id):
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
dataset_process_rule = db.session.execute(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
)
).scalar_one_or_none()
if dataset_process_rule:
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
@@ -225,7 +226,7 @@ class DatasetService:
summary_index_setting: dict | None = None,
):
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@@ -300,17 +301,17 @@ class DatasetService:
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
if db.session.scalar(
select(Dataset)
.where(Dataset.name == rag_pipeline_dataset_create_entity.name, Dataset.tenant_id == tenant_id)
.limit(1)
):
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
else:
# generate an incremental default name: Untitled 1, Untitled 2, Untitled 3, ...
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
datasets = db.session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
@@ -344,7 +345,7 @@ class DatasetService:
@staticmethod
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
dataset: Dataset | None = db.session.get(Dataset, dataset_id)
return dataset
@staticmethod
@@ -466,14 +467,14 @@ class DatasetService:
@staticmethod
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
dataset = (
db.session.query(Dataset)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id != dataset_id,
Dataset.name == name,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
return dataset is not None
@@ -596,7 +597,7 @@ class DatasetService:
filtered_data["icon_info"] = data.get("icon_info")
# Update dataset in database
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data))
db.session.commit()
# Reload dataset to get updated values
@@ -631,7 +632,7 @@ class DatasetService:
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
return
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
pipeline = db.session.get(Pipeline, dataset.pipeline_id)
if not pipeline:
return
@@ -1138,8 +1139,10 @@ class DatasetService:
if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
# For partial team permission, the user must be the creator or have an explicit permission entry
if dataset.created_by != user.id:
user_permission = (
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
user_permission = db.session.scalar(
select(DatasetPermission)
.where(DatasetPermission.dataset_id == dataset.id, DatasetPermission.account_id == user.id)
.limit(1)
)
if not user_permission:
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
@@ -1161,7 +1164,9 @@ class DatasetService:
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
for dp in db.session.scalars(
select(DatasetPermission).where(DatasetPermission.account_id == user.id)
).all()
):
raise NoPermissionError("You do not have permission to access this dataset.")
@@ -1175,12 +1180,11 @@ class DatasetService:
@staticmethod
def get_related_apps(dataset_id: str):
return (
db.session.query(AppDatasetJoin)
return db.session.scalars(
select(AppDatasetJoin)
.where(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
.order_by(AppDatasetJoin.created_at.desc())
).all()
@staticmethod
def update_dataset_api_status(dataset_id: str, status: bool):