mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 19:09:21 +08:00
refactor: select in dataset_service (DocumentService class) (#34528)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1400,8 +1400,8 @@ class DocumentService:
|
||||
@staticmethod
|
||||
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
|
||||
if document_id:
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
document = db.session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
return document
|
||||
else:
|
||||
@@ -1630,7 +1630,7 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_id(document_id: str) -> Document | None:
|
||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
document = db.session.get(Document, document_id)
|
||||
|
||||
return document
|
||||
|
||||
@@ -1695,7 +1695,7 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def get_document_file_detail(file_id: str):
|
||||
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
|
||||
file_detail = db.session.get(UploadFile, file_id)
|
||||
return file_detail
|
||||
|
||||
@staticmethod
|
||||
@@ -1769,9 +1769,11 @@ class DocumentService:
|
||||
document.name = name
|
||||
db.session.add(document)
|
||||
if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
|
||||
db.session.query(UploadFile).where(
|
||||
UploadFile.id == document.data_source_info_dict["upload_file_id"]
|
||||
).update({UploadFile.name: name})
|
||||
db.session.execute(
|
||||
update(UploadFile)
|
||||
.where(UploadFile.id == document.data_source_info_dict["upload_file_id"])
|
||||
.values(name=name)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@@ -1858,8 +1860,8 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def get_documents_position(dataset_id):
|
||||
document = (
|
||||
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
||||
document = db.session.scalar(
|
||||
select(Document).where(Document.dataset_id == dataset_id).order_by(Document.position.desc()).limit(1)
|
||||
)
|
||||
if document:
|
||||
return document.position + 1
|
||||
@@ -2016,28 +2018,28 @@ class DocumentService:
|
||||
if not knowledge_config.data_source.info_list.file_info_list:
|
||||
raise ValueError("File source info is required")
|
||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
||||
files = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id.in_(upload_file_list),
|
||||
)
|
||||
.all()
|
||||
files = list(
|
||||
db.session.scalars(
|
||||
select(UploadFile).where(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id.in_(upload_file_list),
|
||||
)
|
||||
).all()
|
||||
)
|
||||
if len(files) != len(set(upload_file_list)):
|
||||
raise FileNotExistsError("One or more files not found.")
|
||||
|
||||
file_names = [file.name for file in files]
|
||||
db_documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.tenant_id == current_user.current_tenant_id,
|
||||
Document.data_source_type == DataSourceType.UPLOAD_FILE,
|
||||
Document.enabled == True,
|
||||
Document.name.in_(file_names),
|
||||
)
|
||||
.all()
|
||||
db_documents = list(
|
||||
db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.tenant_id == current_user.current_tenant_id,
|
||||
Document.data_source_type == DataSourceType.UPLOAD_FILE,
|
||||
Document.enabled == True,
|
||||
Document.name.in_(file_names),
|
||||
)
|
||||
).all()
|
||||
)
|
||||
documents_map = {document.name: document for document in db_documents}
|
||||
for file in files:
|
||||
@@ -2083,15 +2085,15 @@ class DocumentService:
|
||||
raise ValueError("No notion info list found.")
|
||||
exist_page_ids = []
|
||||
exist_document = {}
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
enabled=True,
|
||||
)
|
||||
.all()
|
||||
documents = list(
|
||||
db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.tenant_id == current_user.current_tenant_id,
|
||||
Document.data_source_type == DataSourceType.NOTION_IMPORT,
|
||||
Document.enabled == True,
|
||||
)
|
||||
).all()
|
||||
)
|
||||
if documents:
|
||||
for document in documents:
|
||||
@@ -2522,14 +2524,15 @@ class DocumentService:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
documents_count = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
Document.completed_at.isnot(None),
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.tenant_id == current_user.current_tenant_id,
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.completed_at.isnot(None),
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
return documents_count
|
||||
|
||||
@@ -2579,10 +2582,10 @@ class DocumentService:
|
||||
raise ValueError("No file info list found.")
|
||||
upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
|
||||
for file_id in upload_file_list:
|
||||
file = (
|
||||
db.session.query(UploadFile)
|
||||
file = db.session.scalar(
|
||||
select(UploadFile)
|
||||
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# raise error if file not found
|
||||
@@ -2599,8 +2602,8 @@ class DocumentService:
|
||||
notion_info_list = document_data.data_source.info_list.notion_info_list
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info.workspace_id
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding)
|
||||
.where(
|
||||
sa.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
@@ -2609,7 +2612,7 @@ class DocumentService:
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not data_source_binding:
|
||||
raise ValueError("Data source binding not found.")
|
||||
@@ -2654,8 +2657,10 @@ class DocumentService:
|
||||
db.session.commit()
|
||||
# update document segment
|
||||
|
||||
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
|
||||
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
|
||||
db.session.execute(
|
||||
update(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == document.id)
|
||||
.values(status=SegmentStatus.RE_SEGMENT)
|
||||
)
|
||||
db.session.commit()
|
||||
# trigger async task
|
||||
|
||||
Reference in New Issue
Block a user