mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 09:49:25 +08:00
refactor: select in rag_pipeline (#34495)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -156,27 +156,27 @@ class RagPipelineService:
|
||||
:param template_id: template id
|
||||
:param template_info: template info
|
||||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not customized_template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
# check template name is exist
|
||||
template_name = template_info.name
|
||||
if template_name:
|
||||
template = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
template = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
PipelineCustomizedTemplate.id != template_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if template:
|
||||
raise ValueError("Template name is already exists")
|
||||
@@ -192,13 +192,13 @@ class RagPipelineService:
|
||||
"""
|
||||
Delete customized pipeline template.
|
||||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not customized_template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
@@ -210,14 +210,14 @@ class RagPipelineService:
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by rag pipeline
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# return draft workflow
|
||||
@@ -232,28 +232,28 @@ class RagPipelineService:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.id == pipeline.workflow_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
"""Fetch a published workflow snapshot by ID for restore operations."""
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.id == workflow_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if workflow and workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise IsDraftWorkflowError("source workflow must be published")
|
||||
@@ -974,7 +974,7 @@ class RagPipelineService:
|
||||
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
|
||||
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
|
||||
if document_id:
|
||||
document = db.session.query(Document).where(Document.id == document_id.value).first()
|
||||
document = db.session.get(Document, document_id.value)
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = error
|
||||
@@ -1178,12 +1178,12 @@ class RagPipelineService:
|
||||
"""
|
||||
Publish customized pipeline template
|
||||
"""
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
|
||||
pipeline = db.session.get(Pipeline, pipeline_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
if not pipeline.workflow_id:
|
||||
raise ValueError("Pipeline workflow not found")
|
||||
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
|
||||
workflow = db.session.get(Workflow, pipeline.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
with Session(db.engine) as session:
|
||||
@@ -1194,21 +1194,21 @@ class RagPipelineService:
|
||||
# check template name is exist
|
||||
template_name = args.get("name")
|
||||
if template_name:
|
||||
template = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
template = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if template:
|
||||
raise ValueError("Template name is already exists")
|
||||
|
||||
max_position = (
|
||||
db.session.query(func.max(PipelineCustomizedTemplate.position))
|
||||
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
|
||||
.scalar()
|
||||
max_position = db.session.scalar(
|
||||
select(func.max(PipelineCustomizedTemplate.position)).where(
|
||||
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
|
||||
)
|
||||
)
|
||||
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
@@ -1239,13 +1239,14 @@ class RagPipelineService:
|
||||
|
||||
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
|
||||
return (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
db.session.scalar(
|
||||
select(func.count(Workflow.id)).where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
) > 0
|
||||
|
||||
def get_node_last_run(
|
||||
@@ -1353,11 +1354,11 @@ class RagPipelineService:
|
||||
|
||||
def get_recommended_plugins(self, type: str) -> dict:
|
||||
# Query active recommended plugins
|
||||
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
|
||||
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
|
||||
if type and type != "all":
|
||||
query = query.where(PipelineRecommendedPlugin.type == type)
|
||||
stmt = stmt.where(PipelineRecommendedPlugin.type == type)
|
||||
|
||||
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
|
||||
pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()
|
||||
|
||||
if not pipeline_recommended_plugins:
|
||||
return {
|
||||
@@ -1396,14 +1397,12 @@ class RagPipelineService:
|
||||
"""
|
||||
Retry error document
|
||||
"""
|
||||
document_pipeline_execution_log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document.id)
|
||||
.first()
|
||||
document_pipeline_execution_log = db.session.scalar(
|
||||
select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
|
||||
)
|
||||
if not document_pipeline_execution_log:
|
||||
raise ValueError("Document pipeline execution log not found")
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
|
||||
pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
# convert to app config
|
||||
@@ -1432,23 +1431,23 @@ class RagPipelineService:
|
||||
"""
|
||||
Get datasource plugins
|
||||
"""
|
||||
dataset: Dataset | None = (
|
||||
db.session.query(Dataset)
|
||||
dataset: Dataset | None = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = (
|
||||
db.session.query(Pipeline)
|
||||
pipeline: Pipeline | None = db.session.scalar(
|
||||
select(Pipeline)
|
||||
.where(
|
||||
Pipeline.id == dataset.pipeline_id,
|
||||
Pipeline.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
@@ -1530,23 +1529,23 @@ class RagPipelineService:
|
||||
"""
|
||||
Get pipeline
|
||||
"""
|
||||
dataset: Dataset | None = (
|
||||
db.session.query(Dataset)
|
||||
dataset: Dataset | None = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = (
|
||||
db.session.query(Pipeline)
|
||||
pipeline: Pipeline | None = db.session.scalar(
|
||||
select(Pipeline)
|
||||
.where(
|
||||
Pipeline.id == dataset.pipeline_id,
|
||||
Pipeline.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
Reference in New Issue
Block a user