Files
dify/api/services/human_input_delivery_test_service.py
99 40591a7c50 refactor(api): use standalone graphon package (#34209)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-27 21:05:32 +00:00

253 lines
8.4 KiB
Python

from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Protocol
from graphon.runtime import VariablePool
from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.workflow.human_input_compat import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
ExternalRecipient,
MemberRecipient,
)
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_template_renderer import render_email_template
from models import Account, TenantAccountJoin
from services.feature_service import FeatureService
class DeliveryTestStatus(StrEnum):
OK = "ok"
FAILED = "failed"
@dataclass(frozen=True)
class DeliveryTestEmailRecipient:
    """Immutable pairing of a recipient address with its form token."""

    # Address the test email is addressed to.
    email: str
    # Token for this recipient's form link.
    form_token: str
@dataclass(frozen=True)
class DeliveryTestContext:
    """Immutable inputs used to render and send a test delivery for one node."""

    tenant_id: str
    app_id: str
    node_id: str
    node_title: str | None
    rendered_content: str
    # Extra template substitutions supplied by the caller.
    template_vars: dict[str, str] = field(default_factory=dict)
    # Per-recipient form tokens; each instance gets its own fresh list.
    recipients: list[DeliveryTestEmailRecipient] = field(default_factory=list)
    variable_pool: VariablePool | None = field(default=None)
@dataclass(frozen=True)
class DeliveryTestResult:
    """Immutable summary of a delivery test send."""

    status: DeliveryTestStatus
    # Addresses a message was handed to, in send order.
    delivered_to: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
class DeliveryTestError(Exception):
    """Raised when a delivery test send cannot be performed."""


class DeliveryTestUnsupportedError(DeliveryTestError):
    """Raised when the delivery method does not support test sending."""
def _build_form_link(token: str | None) -> str | None:
if not token:
return None
base_url = dify_config.APP_WEB_URL
if not base_url:
return None
return f"{base_url.rstrip('/')}/form/{token}"
class DeliveryTestHandler(Protocol):
    """Structural interface for objects that can test-send one kind of delivery method."""

    def supports(self, method: DeliveryChannelConfig) -> bool:
        """Return whether this handler can test-send *method*."""

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Perform a test send of *method* using *context*."""
class DeliveryTestRegistry:
    """Ordered collection of handlers; the first handler that supports a method wins."""

    def __init__(self, handlers: list[DeliveryTestHandler] | None = None) -> None:
        # Copy so later mutation of the caller's list does not affect the registry.
        self._handlers = list(handlers) if handlers else []

    def register(self, handler: DeliveryTestHandler) -> None:
        """Append *handler*; it is consulted after all previously registered handlers."""
        self._handlers.append(handler)

    def dispatch(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Route the test send to the first handler whose ``supports`` returns True.

        Raises:
            DeliveryTestUnsupportedError: no registered handler supports *method*.
        """
        # Lazy search: ``supports`` is only called until the first match.
        chosen = next((h for h in self._handlers if h.supports(method)), None)
        if chosen is None:
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
        return chosen.send_test(context=context, method=method)

    @classmethod
    def default(cls) -> DeliveryTestRegistry:
        """Build a registry containing only the email handler."""
        return cls(handlers=[EmailDeliveryTestHandler()])
class HumanInputDeliveryTestService:
    """Facade that forwards delivery test sends to a handler registry."""

    def __init__(self, registry: DeliveryTestRegistry | None = None) -> None:
        """Use *registry* if given, otherwise ``DeliveryTestRegistry.default()``."""
        if not registry:
            registry = DeliveryTestRegistry.default()
        self._registry = registry

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Dispatch a test send of *method*; errors from the registry propagate."""
        registry = self._registry
        return registry.dispatch(context=context, method=method)
class EmailDeliveryTestHandler:
    """Test-send handler for ``EmailDeliveryMethod`` delivery configs.

    Recipients come from the method config: external addresses plus workspace
    members. Members are either every member of the tenant, when
    ``include_bound_group`` is set, or only the referenced ones. Each recipient
    receives the subject and body rendered with test-time substitutions.
    """

    def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
        """Create the handler.

        Args:
            session_factory: A ``sessionmaker`` used for member lookups, or an
                ``Engine`` to wrap in one. When omitted, a ``sessionmaker``
                bound to ``db.engine`` is used.
        """
        if session_factory is None:
            session_factory = sessionmaker(bind=db.engine)
        elif isinstance(session_factory, Engine):
            session_factory = sessionmaker(bind=session_factory)
        self._session_factory = session_factory

    def supports(self, method: DeliveryChannelConfig) -> bool:
        """Return True when *method* is an ``EmailDeliveryMethod``."""
        return isinstance(method, EmailDeliveryMethod)

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Render and send the configured email to every resolved recipient.

        Returns:
            A result with status OK listing each address handed to ``mail.send``.

        Raises:
            DeliveryTestUnsupportedError: *method* is not an email delivery method.
            DeliveryTestError: email delivery is not enabled for the tenant's
                plan, the mail client is not initialized, or no recipients
                resolve.

        Exceptions raised by ``mail.send`` are not caught here. A failure stops
        sending to the remaining recipients.
        """
        if not isinstance(method, EmailDeliveryMethod):
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")

        features = FeatureService.get_features(context.tenant_id)
        if not features.human_input_email_delivery_enabled:
            raise DeliveryTestError("Email delivery is not available for current plan.")
        if not mail.is_inited():
            raise DeliveryTestError("Mail client is not initialized.")

        recipients = self._resolve_recipients(
            tenant_id=context.tenant_id,
            method=method,
        )
        if not recipients:
            raise DeliveryTestError("No recipients configured for delivery method.")

        delivered: list[str] = []
        for recipient_email in recipients:
            # Substitutions are per recipient: they include the recipient's
            # address and, when available, that recipient's form token and link.
            substitutions = self._build_substitutions(
                context=context,
                recipient_email=recipient_email,
            )
            subject_template = render_email_template(method.config.subject, substitutions)
            subject = EmailDeliveryConfig.sanitize_subject(subject_template)
            templated_body = EmailDeliveryConfig.render_body_template(
                body=method.config.body,
                url=substitutions.get("form_link"),
                variable_pool=context.variable_pool,
            )
            body = render_email_template(templated_body, substitutions)
            body = EmailDeliveryConfig.render_markdown_body(body)
            mail.send(
                to=recipient_email,
                subject=subject,
                html=body,
            )
            delivered.append(recipient_email)
        return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=delivered)

    def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]:
        """Collect recipient addresses for *method*, de-duplicated in first-seen order.

        External recipients' addresses come first. Member addresses follow:
        every workspace member when ``include_bound_group`` is set, otherwise
        the members referenced by ``MemberRecipient`` entries. Empty addresses
        are dropped.
        """
        recipients = method.config.recipients
        emails: list[str] = []
        bound_reference_ids: list[str] = []
        for recipient in recipients.items:
            if isinstance(recipient, MemberRecipient):
                bound_reference_ids.append(recipient.reference_id)
            elif isinstance(recipient, ExternalRecipient) and recipient.email:
                emails.append(recipient.email)

        if recipients.include_bound_group:
            # The whole workspace is bound, so individual member references are
            # subsumed by the all-members lookup.
            member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None)
            emails.extend(member_emails.values())
        elif bound_reference_ids:
            member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=bound_reference_ids)
            # Preserve the order in which members were referenced in the config.
            for user_id in bound_reference_ids:
                email = member_emails.get(user_id)
                if email:
                    emails.append(email)

        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(email for email in emails if email))

    def _query_workspace_member_emails(
        self,
        *,
        tenant_id: str,
        user_ids: list[str] | None,
    ) -> dict[str, str]:
        """Map account id to email for members of *tenant_id*.

        Args:
            tenant_id: Workspace whose ``TenantAccountJoin`` rows restrict the query.
            user_ids: Account ids to restrict to. ``None`` means all members of
                the workspace. A list with no non-empty ids returns ``{}``
                without querying.

        Accounts with an empty or missing email are omitted, so every value is a
        non-empty string.
        """
        if user_ids is None:
            unique_ids = None
        else:
            unique_ids = {user_id for user_id in user_ids if user_id}
            if not unique_ids:
                return {}

        stmt = (
            select(Account.id, Account.email)
            .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
            .where(TenantAccountJoin.tenant_id == tenant_id)
        )
        if unique_ids is not None:
            stmt = stmt.where(Account.id.in_(unique_ids))

        with self._session_factory() as session:
            rows = session.execute(stmt).tuples().all()
        return {account_id: email for account_id, email in rows if email}

    @staticmethod
    def _build_substitutions(
        *,
        context: DeliveryTestContext,
        recipient_email: str,
    ) -> dict[str, str]:
        """Build template substitutions for one recipient.

        Precedence, lowest to highest:

        1. Built-in keys. ``None`` values become ``""``.
        2. ``context.template_vars``. ``None`` values are skipped.
        3. The recipient's form token and link, taken from the first
           ``context.recipients`` entry whose email equals *recipient_email*.

        The link falls back to the relative ``/form/<token>`` when no absolute
        base URL is available.
        """
        raw_values: dict[str, str | None] = {
            "form_id": "",
            "node_title": context.node_title,
            "workflow_run_id": "",
            "form_token": "",
            "form_link": "",
            "form_content": context.rendered_content,
            "recipient_email": recipient_email,
        }
        substitutions = {key: value or "" for key, value in raw_values.items()}
        if context.template_vars:
            substitutions.update({key: value for key, value in context.template_vars.items() if value is not None})

        token = next(
            (recipient.form_token for recipient in context.recipients if recipient.email == recipient_email),
            None,
        )
        if token:
            substitutions["form_token"] = token
            link = _build_form_link(token)
            substitutions["form_link"] = link if link is not None else f"/form/{token}"
        return substitutions