Files
dify/api/repositories/sqlalchemy_execution_extra_content_repository.py
99 40591a7c50 refactor(api): use standalone graphon package (#34209)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-27 21:05:32 +00:00

201 lines
7.8 KiB
Python

from __future__ import annotations
import json
import logging
import re
from collections import defaultdict
from collections.abc import Sequence
from typing import Any
from graphon.nodes.human_input.entities import FormDefinition
from graphon.nodes.human_input.enums import HumanInputFormStatus
from graphon.nodes.human_input.human_input_node import HumanInputNode
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.entities.execution_extra_content import (
ExecutionExtraContentDomainModel,
HumanInputFormDefinition,
HumanInputFormSubmissionData,
)
from core.entities.execution_extra_content import (
HumanInputContent as HumanInputContentDomainModel,
)
from models.execution_extra_content import (
ExecutionExtraContent as ExecutionExtraContentModel,
)
from models.execution_extra_content import (
HumanInputContent as HumanInputContentModel,
)
from models.human_input import HumanInputFormRecipient, RecipientType
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
logger = logging.getLogger(__name__)
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
def _extract_output_field_names(form_content: str) -> list[str]:
if not form_content:
return []
return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)]
class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
    """Load execution extra content through SQLAlchemy and map it to domain models.

    Only ``HumanInputContent`` rows are mapped; any other content type is
    skipped with a debug log entry.
    """

    def __init__(self, session_maker: sessionmaker[Session]):
        """Keep the session factory; a session is opened per query.

        Args:
            session_maker: Factory used to create SQLAlchemy sessions.
        """
        self._session_maker = session_maker

    def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
        """Return the extra content of each message, aligned with ``message_ids``.

        Args:
            message_ids: Message identifiers to look up.

        Returns:
            One list per entry of ``message_ids``, in the same order. Each list
            holds the mapped content of that message ordered by ``created_at``
            ascending; it is empty when nothing was found or nothing could be
            mapped. An empty ``message_ids`` yields an empty list.
        """
        if not message_ids:
            return []

        grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = {
            message_id: [] for message_id in message_ids
        }

        stmt = (
            select(ExecutionExtraContentModel)
            .where(ExecutionExtraContentModel.message_id.in_(message_ids))
            .options(selectinload(HumanInputContentModel.form))
            .order_by(ExecutionExtraContentModel.created_at.asc())
        )

        with self._session_maker() as session:
            results = session.scalars(stmt).all()

            # Fetch the recipients of every referenced form in one query
            # instead of one query per form.
            form_ids = {
                content.form_id
                for content in results
                if isinstance(content, HumanInputContentModel) and content.form_id is not None
            }
            recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list)
            if form_ids:
                recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
                for recipient in session.scalars(recipient_stmt).all():
                    recipients_by_form_id[recipient.form_id].append(recipient)

            # Mapping is done while the session is still open.
            for content in results:
                message_id = content.message_id
                if not message_id or message_id not in grouped_contents:
                    continue

                domain_model = self._map_model_to_domain(content, recipients_by_form_id)
                if domain_model is not None:
                    grouped_contents[message_id].append(domain_model)

        return [grouped_contents[message_id] for message_id in message_ids]

    def _map_model_to_domain(
        self,
        model: ExecutionExtraContentModel,
        recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
    ) -> ExecutionExtraContentDomainModel | None:
        """Dispatch an ORM row to the mapper for its concrete type.

        Returns:
            The mapped domain model, or ``None`` when the row type is not
            supported or the row could not be mapped.
        """
        if isinstance(model, HumanInputContentModel):
            return self._map_human_input_content(model, recipients_by_form_id)

        logger.debug("Unsupported execution extra content type encountered: %s", model.type)
        return None

    def _map_human_input_content(
        self,
        model: HumanInputContentModel,
        recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
    ) -> HumanInputContentDomainModel | None:
        """Map a human-input content row to its domain model.

        A form that has not been submitted is mapped to its definition (plus
        an access token when a suitable recipient exists); a submitted form is
        mapped to its submission data with the rendered content.

        Args:
            model: Row to map; its ``form`` relationship is read.
            recipients_by_form_id: Recipients grouped by form id, used to
                resolve the access token of unsubmitted forms.

        Returns:
            The domain model, or ``None`` (after logging a warning) when the
            form is missing, its definition or submitted data cannot be
            parsed, or a submitted form has no selected action.
        """
        form = model.form
        if form is None:
            logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
            return None

        try:
            definition_payload = json.loads(form.form_definition)
            # Valid JSON that is not an object (null, string, number, list)
            # would otherwise raise an uncaught TypeError below and abort the
            # whole batch; treat it like any other malformed definition.
            if not isinstance(definition_payload, dict):
                raise ValueError("form definition must be a JSON object")
            # Fall back to the form's own expiration time when the stored
            # definition does not carry one.
            definition_payload.setdefault("expiration_time", form.expiration_time)
            form_definition = FormDefinition.model_validate(definition_payload)
        except ValueError:
            logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
            return None

        node_title = form_definition.node_title or form.node_id
        display_in_ui = bool(form_definition.display_in_ui)

        submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED
        if not submitted:
            form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, []))
            return HumanInputContentDomainModel(
                workflow_run_id=model.workflow_run_id,
                submitted=False,
                form_definition=HumanInputFormDefinition(
                    form_id=form.id,
                    node_id=form.node_id,
                    node_title=node_title,
                    form_content=form.rendered_content,
                    inputs=form_definition.inputs,
                    actions=form_definition.user_actions,
                    display_in_ui=display_in_ui,
                    form_token=form_token,
                    resolved_default_values=form_definition.default_values,
                    # NOTE(review): ``timestamp()`` interprets a naive datetime
                    # as local time; confirm ``expiration_time`` is
                    # timezone-aware or that the process runs in UTC.
                    expiration_time=int(form.expiration_time.timestamp()),
                ),
            )

        selected_action_id = form.selected_action_id
        if not selected_action_id:
            logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
            return None

        # Use the action's title when the definition still lists the selected
        # action; otherwise fall back to the raw action id.
        action_text = next(
            (action.title for action in form_definition.user_actions if action.id == selected_action_id),
            selected_action_id,
        )

        submitted_data: dict[str, Any] = {}
        if form.submitted_data:
            try:
                # NOTE(review): a non-object JSON value is passed through
                # unchanged; verify the renderer tolerates that.
                submitted_data = json.loads(form.submitted_data)
            except ValueError:
                logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id)
                return None

        rendered_content = HumanInputNode.render_form_content_with_outputs(
            form.rendered_content,
            submitted_data,
            _extract_output_field_names(form_definition.form_content),
        )

        return HumanInputContentDomainModel(
            workflow_run_id=model.workflow_run_id,
            submitted=True,
            form_submission_data=HumanInputFormSubmissionData(
                node_id=form.node_id,
                node_title=node_title,
                rendered_content=rendered_content,
                action_id=selected_action_id,
                action_text=action_text,
            ),
        )

    @staticmethod
    def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None:
        """Pick the access token to expose for an unsubmitted form.

        The first console recipient is preferred; if it is absent or has no
        access token, the first standalone web app recipient is tried.

        Returns:
            The chosen access token, or ``None`` when neither recipient
            provides one.
        """
        console_recipient = next(
            (recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE),
            None,
        )
        if console_recipient and console_recipient.access_token:
            return console_recipient.access_token

        web_app_recipient = next(
            (recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP),
            None,
        )
        if web_app_recipient and web_app_recipient.access_token:
            return web_app_recipient.access_token

        return None
# Names exported by ``from <this module> import *``: only the repository class.
__all__ = ["SQLAlchemyExecutionExtraContentRepository"]