mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 09:49:25 +08:00
refactor: select in dataset_service (DatasetService class) (#34525)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,7 @@ from graphon.file import helpers as file_helpers
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
from sqlalchemy import exists, func, select
|
||||
from sqlalchemy import exists, func, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@@ -114,9 +114,11 @@ class DatasetService:
|
||||
|
||||
if user:
|
||||
# get permitted dataset ids
|
||||
dataset_permission = (
|
||||
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
|
||||
)
|
||||
dataset_permission = db.session.scalars(
|
||||
select(DatasetPermission).where(
|
||||
DatasetPermission.account_id == user.id, DatasetPermission.tenant_id == tenant_id
|
||||
)
|
||||
).all()
|
||||
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
|
||||
|
||||
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
|
||||
@@ -182,13 +184,12 @@ class DatasetService:
|
||||
@staticmethod
|
||||
def get_process_rules(dataset_id):
|
||||
# get the latest process rule
|
||||
dataset_process_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
dataset_process_rule = db.session.execute(
|
||||
select(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == dataset_id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if dataset_process_rule:
|
||||
mode = dataset_process_rule.mode
|
||||
rules = dataset_process_rule.rules_dict
|
||||
@@ -225,7 +226,7 @@ class DatasetService:
|
||||
summary_index_setting: dict | None = None,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
|
||||
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
|
||||
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
|
||||
embedding_model = None
|
||||
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
@@ -300,17 +301,17 @@ class DatasetService:
|
||||
):
|
||||
if rag_pipeline_dataset_create_entity.name:
|
||||
# check if dataset name already exists
|
||||
if (
|
||||
db.session.query(Dataset)
|
||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||
.first()
|
||||
if db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(Dataset.name == rag_pipeline_dataset_create_entity.name, Dataset.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
):
|
||||
raise DatasetNameDuplicateError(
|
||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||
)
|
||||
else:
|
||||
# generate a random name as Untitled 1 2 3 ...
|
||||
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
|
||||
datasets = db.session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
|
||||
names = [dataset.name for dataset in datasets]
|
||||
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
|
||||
names,
|
||||
@@ -344,7 +345,7 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(dataset_id) -> Dataset | None:
|
||||
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset: Dataset | None = db.session.get(Dataset, dataset_id)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
@@ -466,14 +467,14 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id != dataset_id,
|
||||
Dataset.name == name,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
return dataset is not None
|
||||
|
||||
@@ -596,7 +597,7 @@ class DatasetService:
|
||||
filtered_data["icon_info"] = data.get("icon_info")
|
||||
|
||||
# Update dataset in database
|
||||
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
|
||||
db.session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data))
|
||||
db.session.commit()
|
||||
|
||||
# Reload dataset to get updated values
|
||||
@@ -631,7 +632,7 @@ class DatasetService:
|
||||
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
|
||||
return
|
||||
|
||||
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
|
||||
pipeline = db.session.get(Pipeline, dataset.pipeline_id)
|
||||
if not pipeline:
|
||||
return
|
||||
|
||||
@@ -1138,8 +1139,10 @@ class DatasetService:
|
||||
if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
# For partial team permission, user needs explicit permission or be the creator
|
||||
if dataset.created_by != user.id:
|
||||
user_permission = (
|
||||
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
|
||||
user_permission = db.session.scalar(
|
||||
select(DatasetPermission)
|
||||
.where(DatasetPermission.dataset_id == dataset.id, DatasetPermission.account_id == user.id)
|
||||
.limit(1)
|
||||
)
|
||||
if not user_permission:
|
||||
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
||||
@@ -1161,7 +1164,9 @@ class DatasetService:
|
||||
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
if not any(
|
||||
dp.dataset_id == dataset.id
|
||||
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
|
||||
for dp in db.session.scalars(
|
||||
select(DatasetPermission).where(DatasetPermission.account_id == user.id)
|
||||
).all()
|
||||
):
|
||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||
|
||||
@@ -1175,12 +1180,11 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_related_apps(dataset_id: str):
|
||||
return (
|
||||
db.session.query(AppDatasetJoin)
|
||||
return db.session.scalars(
|
||||
select(AppDatasetJoin)
|
||||
.where(AppDatasetJoin.dataset_id == dataset_id)
|
||||
.order_by(db.desc(AppDatasetJoin.created_at))
|
||||
.all()
|
||||
)
|
||||
.order_by(AppDatasetJoin.created_at.desc())
|
||||
).all()
|
||||
|
||||
@staticmethod
|
||||
def update_dataset_api_status(dataset_id: str, status: bool):
|
||||
|
||||
Reference in New Issue
Block a user