refactor(api): Query API to select function_1 (#33565)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-03-17 15:29:16 +01:00
committed by GitHub
parent 076b297b18
commit 7757bb5089
11 changed files with 196 additions and 258 deletions

View File

@@ -8,6 +8,7 @@ import os
import pickle
import re
import time
from collections.abc import Sequence
from datetime import datetime
from json import JSONDecodeError
from typing import Any, TypedDict, cast
@@ -145,30 +146,25 @@ class Dataset(Base):
@property
def total_documents(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def total_available_documents(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def dataset_keyword_table(self):
dataset_keyword_table = (
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
)
if dataset_keyword_table:
return dataset_keyword_table
return None
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
@property
def index_struct_dict(self):
@@ -195,64 +191,66 @@ class Dataset(Base):
@property
def latest_process_rule(self):
return (
db.session.query(DatasetProcessRule)
return db.session.scalar(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
.limit(1)
)
@property
def app_count(self):
return (
db.session.query(func.count(AppDatasetJoin.id))
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar()
db.session.scalar(
select(func.count(AppDatasetJoin.id)).where(
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
)
)
or 0
)
@property
def document_count(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def available_document_count(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def available_segment_count(self):
return (
db.session.query(func.count(DocumentSegment.id))
.where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
)
.scalar()
or 0
)
@property
def word_count(self):
return (
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
.where(Document.dataset_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
)
@property
def doc_form(self) -> str | None:
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
if document:
return document.doc_form
return None
@@ -270,8 +268,8 @@ class Dataset(Base):
@property
def tags(self):
tags = (
db.session.query(Tag)
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@@ -279,8 +277,7 @@ class Dataset(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "knowledge",
)
.all()
)
).all()
return tags or []
@@ -288,8 +285,8 @@ class Dataset(Base):
def external_knowledge_info(self):
if self.provider != "external":
return None
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
)
if not external_knowledge_binding:
return None
@@ -310,7 +307,7 @@ class Dataset(Base):
@property
def is_published(self):
if self.pipeline_id:
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
if pipeline:
return pipeline.is_published
return False
@@ -521,10 +518,8 @@ class Document(Base):
if self.data_source_info:
if self.data_source_type == "upload_file":
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none()
file_detail = db.session.scalar(
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
)
if file_detail:
return {
@@ -557,24 +552,23 @@ class Document(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def segment_count(self):
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
return (
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
)
@property
def hit_count(self):
return (
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.where(DocumentSegment.document_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
)
@property
def uploader(self):
user = db.session.query(Account).where(Account.id == self.created_by).first()
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
return user.name if user else None
@property
@@ -588,14 +582,13 @@ class Document(Base):
@property
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
document_metadatas = db.session.scalars(
select(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
.all()
)
).all()
metadata_list: list[DocMetadataDetailItem] = []
for metadata in document_metadatas:
metadata_dict: DocMetadataDetailItem = {
@@ -826,7 +819,7 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self) -> list[Any]:
def child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@@ -835,16 +828,13 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
def get_child_chunks(self) -> list[Any]:
def get_child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@@ -853,12 +843,9 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
@@ -1007,15 +994,15 @@ class ChildChunk(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def document(self):
return db.session.query(Document).where(Document.id == self.document_id).first()
return db.session.scalar(select(Document).where(Document.id == self.document_id))
@property
def segment(self):
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
class AppDatasetJoin(TypeBase):
@@ -1076,7 +1063,7 @@ class DatasetQuery(TypeBase):
if isinstance(queries, list):
for query in queries:
if query["content_type"] == QueryType.IMAGE_QUERY:
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
if file_info:
query["file_info"] = {
"id": file_info.id,
@@ -1141,7 +1128,7 @@ class DatasetKeywordTable(TypeBase):
super().__init__(object_hook=object_hook, *args, **kwargs)
# get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
if not dataset:
return None
if self.data_source_type == "database":
@@ -1535,7 +1522,7 @@ class PipelineCustomizedTemplate(TypeBase):
@property
def created_user_name(self):
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
return ""
@@ -1570,7 +1557,7 @@ class Pipeline(TypeBase):
)
def retrieve_dataset(self, session: Session):
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
class DocumentPipelineExecutionLog(TypeBase):