mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 09:49:25 +08:00
feat(api): Human Input Node (backend part) (#31646)
The backend part of the human in the loop (HITL) feature and relevant architecture / workflow engine changes. Signed-off-by: yihong0618 <zouzou0208@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: 盐粒 Yanli <yanli@dify.ai> Co-authored-by: CrabSAMA <40541269+CrabSAMA@users.noreply.github.com> Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
@@ -162,7 +162,7 @@ class RedisSubscriptionBase(Subscription):
|
||||
self._start_if_needed()
|
||||
return iter(self._message_iterator())
|
||||
|
||||
def receive(self, timeout: float | None = None) -> bytes | None:
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
"""Receive the next message from the subscription."""
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
|
||||
|
||||
@@ -61,7 +61,14 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
|
||||
def _get_message(self) -> dict | None:
|
||||
assert self._pubsub is not None
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined]
|
||||
# NOTE(QuantumGhost): this is an issue in
|
||||
# upstream code. If Sharded PubSub is used with Cluster, the
|
||||
# `ClusterPubSub.get_sharded_message` will return `None` regardless of
|
||||
# message['type'].
|
||||
#
|
||||
# Since we have already filtered at the caller's site, we can safely set
|
||||
# `ignore_subscribe_messages=False`.
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
|
||||
|
||||
def _get_message_type(self) -> str:
|
||||
return "smessage"
|
||||
|
||||
49
api/libs/email_template_renderer.py
Normal file
49
api/libs/email_template_renderer.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Email template rendering helpers with configurable safety modes.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import render_template_string
|
||||
from jinja2.runtime import Context
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
|
||||
from configs import dify_config
|
||||
from configs.feature import TemplateMode
|
||||
|
||||
|
||||
class SandboxedEnvironment(ImmutableSandboxedEnvironment):
|
||||
"""Sandboxed environment with execution timeout."""
|
||||
|
||||
def __init__(self, timeout: int, *args: Any, **kwargs: Any):
|
||||
self._deadline = time.time() + timeout if timeout else None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
if self._deadline is not None and time.time() > self._deadline:
|
||||
raise TimeoutError("Template rendering timeout")
|
||||
return super().call(context, obj, *args, **kwargs)
|
||||
|
||||
|
||||
def render_email_template(template: str, substitutions: Mapping[str, str]) -> str:
|
||||
"""
|
||||
Render email template content according to the configured template mode.
|
||||
|
||||
In unsafe mode, Jinja expressions are evaluated directly.
|
||||
In sandbox mode, a sandboxed environment with timeout is used.
|
||||
In disabled mode, the template is returned without rendering.
|
||||
"""
|
||||
mode = dify_config.MAIL_TEMPLATING_MODE
|
||||
timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
|
||||
|
||||
if mode == TemplateMode.UNSAFE:
|
||||
return render_template_string(template, **substitutions)
|
||||
if mode == TemplateMode.SANDBOX:
|
||||
env = SandboxedEnvironment(timeout=timeout)
|
||||
tmpl = env.from_string(template)
|
||||
return tmpl.render(substitutions)
|
||||
if mode == TemplateMode.DISABLED:
|
||||
return template
|
||||
raise ValueError(f"Unsupported mail templating mode: {mode}")
|
||||
@@ -1,12 +1,15 @@
|
||||
import contextvars
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeVar
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from flask import Flask, g
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import Account, EndUser
|
||||
|
||||
|
||||
@contextmanager
|
||||
def preserve_flask_contexts(
|
||||
@@ -64,3 +67,7 @@ def preserve_flask_contexts(
|
||||
finally:
|
||||
# Any cleanup can be added here if needed
|
||||
pass
|
||||
|
||||
|
||||
def set_login_user(user: "Account | EndUser"):
|
||||
g._login_user = user
|
||||
|
||||
@@ -7,10 +7,10 @@ import struct
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
|
||||
from uuid import UUID
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
@@ -126,6 +126,13 @@ class TimestampField(fields.Raw):
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class OptionalTimestampField(fields.Raw):
|
||||
def format(self, value) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def email(email):
|
||||
# Define a regex pattern for email addresses
|
||||
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
|
||||
@@ -237,6 +244,26 @@ def convert_datetime_to_date(field, target_timezone: str = ":tz"):
|
||||
|
||||
|
||||
def generate_string(n):
|
||||
"""
|
||||
Generates a cryptographically secure random string of the specified length.
|
||||
|
||||
This function uses a cryptographically secure pseudorandom number generator (CSPRNG)
|
||||
to create a string composed of ASCII letters (both uppercase and lowercase) and digits.
|
||||
|
||||
Each character in the generated string provides approximately 5.95 bits of entropy
|
||||
(log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the
|
||||
length of the string (`n`) should be at least 22 characters.
|
||||
|
||||
Args:
|
||||
n (int): The length of the random string to generate. For secure usage,
|
||||
`n` should be 22 or greater.
|
||||
|
||||
Returns:
|
||||
str: A random string of length `n` composed of ASCII letters and digits.
|
||||
|
||||
Note:
|
||||
This function is suitable for generating credentials or other secure tokens.
|
||||
"""
|
||||
letters_digits = string.ascii_letters + string.digits
|
||||
result = ""
|
||||
for _ in range(n):
|
||||
@@ -405,11 +432,35 @@ class TokenManager:
|
||||
return f"{token_type}:account:{account_id}"
|
||||
|
||||
|
||||
class _RateLimiterRedisClient(Protocol):
|
||||
def zadd(self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes]) -> int: ...
|
||||
|
||||
def zremrangebyscore(self, name: str | bytes, min: str | float, max: str | float) -> int: ...
|
||||
|
||||
def zcard(self, name: str | bytes) -> int: ...
|
||||
|
||||
def expire(self, name: str | bytes, time: int) -> bool: ...
|
||||
|
||||
|
||||
def _default_rate_limit_member_factory() -> str:
|
||||
current_time = int(time.time())
|
||||
return f"{current_time}:{secrets.token_urlsafe(nbytes=8)}"
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, prefix: str, max_attempts: int, time_window: int):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
max_attempts: int,
|
||||
time_window: int,
|
||||
member_factory: Callable[[], str] = _default_rate_limit_member_factory,
|
||||
redis_client: _RateLimiterRedisClient = redis_client,
|
||||
):
|
||||
self.prefix = prefix
|
||||
self.max_attempts = max_attempts
|
||||
self.time_window = time_window
|
||||
self._member_factory = member_factory
|
||||
self._redis_client = redis_client
|
||||
|
||||
def _get_key(self, email: str) -> str:
|
||||
return f"{self.prefix}:{email}"
|
||||
@@ -419,8 +470,8 @@ class RateLimiter:
|
||||
current_time = int(time.time())
|
||||
window_start_time = current_time - self.time_window
|
||||
|
||||
redis_client.zremrangebyscore(key, "-inf", window_start_time)
|
||||
attempts = redis_client.zcard(key)
|
||||
self._redis_client.zremrangebyscore(key, "-inf", window_start_time)
|
||||
attempts = self._redis_client.zcard(key)
|
||||
|
||||
if attempts and int(attempts) >= self.max_attempts:
|
||||
return True
|
||||
@@ -428,7 +479,8 @@ class RateLimiter:
|
||||
|
||||
def increment_rate_limit(self, email: str):
|
||||
key = self._get_key(email)
|
||||
member = self._member_factory()
|
||||
current_time = int(time.time())
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.expire(key, self.time_window * 2)
|
||||
self._redis_client.zadd(key, {member: current_time})
|
||||
self._redis_client.expire(key, self.time_window * 2)
|
||||
|
||||
Reference in New Issue
Block a user