refactor(api): continue decoupling dify_graph from API concerns (#33580)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
@@ -1,10 +1,14 @@
[importlinter]
root_packages =
    core
    constants
    context
    dify_graph
    configs
    controllers
    extensions
    factories
    libs
    models
    tasks
    services
@@ -33,29 +37,19 @@ ignore_imports =
    # TODO(QuantumGhost): fix the import violation later
    dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities

[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
    dify_graph
forbidden_modules =
    extensions.ext_database
    extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
    dify_graph.nodes.llm.node -> extensions.ext_database
    dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
    dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis

[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
type = forbidden
source_modules =
    dify_graph
forbidden_modules =
    constants
    configs
    context
    controllers
    extensions
    factories
    libs
    models
    services
    tasks
@@ -88,46 +82,14 @@ forbidden_modules =
    core.tools
    core.trigger
    core.variables
ignore_imports =
    dify_graph.nodes.llm.llm_utils -> core.model_manager
    dify_graph.nodes.llm.protocols -> core.model_manager
    dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
    dify_graph.nodes.llm.node -> core.tools.signature
    dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
    dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
    dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
    dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
    dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
    dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
    dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
    dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
    dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
    dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
    dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
    dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
    dify_graph.nodes.llm.node -> core.model_manager
    dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
    dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
    dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
    dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
    dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
    dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
    dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
    dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
    dify_graph.nodes.llm.node -> models.dataset
    dify_graph.nodes.llm.file_saver -> core.tools.signature
    dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
    dify_graph.nodes.tool.tool_node -> core.tools.errors
    dify_graph.nodes.llm.node -> extensions.ext_database
    dify_graph.nodes.llm.node -> models.model
    dify_graph.nodes.tool.tool_node -> services
    dify_graph.model_runtime.model_providers.__base.ai_model -> configs
    dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
    dify_graph.model_runtime.model_providers.__base.large_language_model -> configs
    dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
    dify_graph.model_runtime.model_providers.model_provider_factory -> configs
    dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
    dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids

[importlinter:contract:workflow-third-party-imports]
name = Workflow Third-Party Imports
type = forbidden
source_modules =
    dify_graph
forbidden_modules =
    sqlalchemy

[importlinter:contract:rsc]
name = RSC
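These contracts are checked with import-linter's `lint-imports` command. As a sketch of what the workflow-infrastructure contract catches (the module path below is hypothetical), an import like this anywhere under `dify_graph` fails the check unless that exact edge is listed under `ignore_imports`:

```python
# dify_graph/nodes/example/node.py -- hypothetical module, for illustration only.
# The workflow-infrastructure contract forbids dify_graph -> extensions.ext_database,
# so lint-imports would report this edge as a contract violation.
from extensions.ext_database import db
```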
@@ -1,74 +1,36 @@
"""
Core Context - Framework-agnostic context management.
Application-layer context adapters.

This module provides context management that is independent of any specific
web framework. Framework-specific implementations register their context
capture functions at application initialization time.

This ensures the workflow layer remains completely decoupled from Flask
or any other web framework.
Concrete execution-context implementations live here so `dify_graph` only
depends on injected context managers rather than framework state capture.
"""

import contextvars
from collections.abc import Callable

from dify_graph.context.execution_context import (
from context.execution_context import (
    AppContext,
    ContextProviderNotFoundError,
    ExecutionContext,
    ExecutionContextBuilder,
    IExecutionContext,
    NullAppContext,
    capture_current_context,
    read_context,
    register_context,
    register_context_capturer,
    reset_context_provider,
)

# Global capturer function - set by framework-specific modules
_capturer: Callable[[], IExecutionContext] | None = None


def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
    """
    Register a context capture function.

    This should be called by framework-specific modules (e.g., Flask)
    during application initialization.

    Args:
        capturer: Function that captures current context and returns IExecutionContext
    """
    global _capturer
    _capturer = capturer


def capture_current_context() -> IExecutionContext:
    """
    Capture current execution context.

    This function uses the registered context capturer. If no capturer
    is registered, it returns a minimal context with only contextvars
    (suitable for non-framework environments like tests or standalone scripts).

    Returns:
        IExecutionContext with captured context
    """
    if _capturer is None:
        # No framework registered - return minimal context
        return ExecutionContext(
            app_context=NullAppContext(),
            context_vars=contextvars.copy_context(),
        )

    return _capturer()


def reset_context_provider() -> None:
    """
    Reset the context capturer.

    This is primarily useful for testing to ensure a clean state.
    """
    global _capturer
    _capturer = None

from context.models import SandboxContext

__all__ = [
    "AppContext",
    "ContextProviderNotFoundError",
    "ExecutionContext",
    "ExecutionContextBuilder",
    "IExecutionContext",
    "NullAppContext",
    "SandboxContext",
    "capture_current_context",
    "read_context",
    "register_context",
    "register_context_capturer",
    "reset_context_provider",
]
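A minimal usage sketch (not part of the diff) of the capture/restore cycle this package re-exports: capture on the request thread, then restore inside a worker. With no framework adapter registered, `capture_current_context()` falls back to `NullAppContext` plus a copy of the current `contextvars`:

```python
import threading

from context.execution_context import capture_current_context


def run_in_worker(work) -> None:
    ctx = capture_current_context()  # capture on the calling thread

    def _target() -> None:
        # ExecutionContext.enter() restores contextvars and enters the
        # captured application context, if any, before running the work.
        with ctx.enter():
            work()

    threading.Thread(target=_target).start()
```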
@@ -1,5 +1,8 @@
"""
Execution Context - Abstracted context management for workflow execution.
Application-layer execution context adapters.

Concrete context capture lives outside `dify_graph` so the graph package only
consumes injected context managers when it needs to preserve thread-local state.
"""

import contextvars
@@ -16,33 +19,33 @@ class AppContext(ABC):
    """
    Abstract application context interface.

    This abstraction allows workflow execution to work with or without Flask
    by providing a common interface for application context management.
    Application adapters can implement this to restore framework-specific state
    such as Flask app context around worker execution.
    """

    @abstractmethod
    def get_config(self, key: str, default: Any = None) -> Any:
        """Get configuration value by key."""
        pass
        raise NotImplementedError

    @abstractmethod
    def get_extension(self, name: str) -> Any:
        """Get Flask extension by name (e.g., 'db', 'cache')."""
        pass
        """Get application extension by name."""
        raise NotImplementedError

    @abstractmethod
    def enter(self) -> AbstractContextManager[None]:
        """Enter the application context."""
        pass
        raise NotImplementedError


@runtime_checkable
class IExecutionContext(Protocol):
    """
    Protocol for execution context.
    Protocol for enterable execution context objects.

    This protocol defines the interface that all execution contexts must implement,
    allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
    Concrete implementations may carry extra framework state, but callers only
    depend on standard context-manager behavior plus optional user metadata.
    """

    def __enter__(self) -> "IExecutionContext":
@@ -62,14 +65,10 @@ class IExecutionContext(Protocol):
@final
class ExecutionContext:
    """
    Execution context for workflow execution in worker threads.
    Generic execution context used by application-layer adapters.

    This class encapsulates all context needed for workflow execution:
    - Application context (Flask app or standalone)
    - Context variables for Python contextvars
    - User information (optional)

    It is designed to be serializable and passable to worker threads.
    It restores captured `contextvars` and optionally enters an application
    context before the worker executes graph logic.
    """

    def __init__(
@@ -78,14 +77,6 @@ class ExecutionContext:
        context_vars: contextvars.Context | None = None,
        user: Any = None,
    ) -> None:
        """
        Initialize execution context.

        Args:
            app_context: Application context (Flask or standalone)
            context_vars: Python contextvars to preserve
            user: User object (optional)
        """
        self._app_context = app_context
        self._context_vars = context_vars
        self._user = user
@@ -98,27 +89,21 @@ class ExecutionContext:

    @property
    def context_vars(self) -> contextvars.Context | None:
        """Get context variables."""
        """Get captured context variables."""
        return self._context_vars

    @property
    def user(self) -> Any:
        """Get user object."""
        """Get captured user object."""
        return self._user

    @contextmanager
    def enter(self) -> Generator[None, None, None]:
        """
        Enter this execution context.

        This is a convenience method that creates a context manager.
        """
        # Restore context variables if provided
        """Enter this execution context."""
        if self._context_vars:
            for var, val in self._context_vars.items():
                var.set(val)

        # Enter app context if available
        if self._app_context is not None:
            with self._app_context.enter():
                yield
@@ -141,18 +126,10 @@ class ExecutionContext:

class NullAppContext(AppContext):
    """
    Null implementation of AppContext for non-Flask environments.

    This is used when running without Flask (e.g., in tests or standalone mode).
    Null application context for non-framework environments.
    """

    def __init__(self, config: dict[str, Any] | None = None) -> None:
        """
        Initialize null app context.

        Args:
            config: Optional configuration dictionary
        """
        self._config = config or {}
        self._extensions: dict[str, Any] = {}

@@ -165,7 +142,7 @@ class NullAppContext(AppContext):
        return self._extensions.get(name)

    def set_extension(self, name: str, extension: Any) -> None:
        """Set extension by name."""
        """Register an extension for tests or standalone execution."""
        self._extensions[name] = extension

    @contextmanager
@@ -176,9 +153,7 @@ class NullAppContext(AppContext):

class ExecutionContextBuilder:
    """
    Builder for creating ExecutionContext instances.

    This provides a fluent API for building execution contexts.
    Builder for creating `ExecutionContext` instances.
    """

    def __init__(self) -> None:
@@ -211,63 +186,42 @@ class ExecutionContextBuilder:


_capturer: Callable[[], IExecutionContext] | None = None

# Tenant-scoped providers using tuple keys for clarity and constant-time lookup.
# Key mapping:
#   (name, tenant_id) -> provider
#   - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox")
#   - tenant_id: tenant identifier string
# Value:
#   provider: Callable[[], BaseModel] returning the typed context value
# Type-safety note:
#   - This registry cannot enforce that all providers for a given name return the same BaseModel type.
#   - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice),
#     e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and
#     def read_sandbox_ctx(tenant_id: str) -> SandboxContext.
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}

T = TypeVar("T", bound=BaseModel)


class ContextProviderNotFoundError(KeyError):
    """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id)."""
    """Raised when a tenant-scoped context provider is missing."""

    pass


def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
    """Register a single enterable execution context capturer (e.g., Flask)."""
    """Register an enterable execution context capturer."""
    global _capturer
    _capturer = capturer


def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None:
    """Register a tenant-specific provider for a named context.

    Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions.
    Consider adding a typed wrapper for this registration in your feature module.
    """
    """Register a tenant-specific provider for a named context."""
    _tenant_context_providers[(name, tenant_id)] = provider


def read_context(name: str, *, tenant_id: str) -> BaseModel:
    """
    Read a context value for a specific tenant.

    Raises KeyError if the provider for (name, tenant_id) is not registered.
    """
    prov = _tenant_context_providers.get((name, tenant_id))
    if prov is None:
    """Read a context value for a specific tenant."""
    provider = _tenant_context_providers.get((name, tenant_id))
    if provider is None:
        raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'")
    return prov()
    return provider()


def capture_current_context() -> IExecutionContext:
    """
    Capture current execution context from the calling environment.

    If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal
    context with NullAppContext + copy of current contextvars.
    If no framework adapter is registered, return a minimal context that only
    restores `contextvars`.
    """
    if _capturer is None:
        return ExecutionContext(
@@ -278,7 +232,22 @@ def capture_current_context() -> IExecutionContext:


def reset_context_provider() -> None:
    """Reset the capturer and all tenant-scoped context providers (primarily for tests)."""
    """Reset the capturer and tenant-scoped providers."""
    global _capturer
    _capturer = None
    _tenant_context_providers.clear()


__all__ = [
    "AppContext",
    "ContextProviderNotFoundError",
    "ExecutionContext",
    "ExecutionContextBuilder",
    "IExecutionContext",
    "NullAppContext",
    "capture_current_context",
    "read_context",
    "register_context",
    "register_context_capturer",
    "reset_context_provider",
]
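The registry comments trimmed above recommended Go-style typed wrappers around `register_context`/`read_context`. A sketch of that pattern, using the `SandboxContext` model the package now re-exports (the wrapper names are illustrative, not part of the codebase):

```python
from collections.abc import Callable

from context.execution_context import read_context, register_context
from context.models import SandboxContext

_SANDBOX_CTX = "workflow.sandbox"  # namespaced name, as the comments advise


def register_sandbox_ctx(tenant_id: str, provider: Callable[[], SandboxContext]) -> None:
    register_context(_SANDBOX_CTX, tenant_id, provider)


def read_sandbox_ctx(tenant_id: str) -> SandboxContext:
    # The registry only guarantees BaseModel, so the wrapper narrows the type.
    value = read_context(_SANDBOX_CTX, tenant_id=tenant_id)
    if not isinstance(value, SandboxContext):
        raise TypeError(f"expected SandboxContext, got {type(value)}")
    return value
```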
@@ -10,11 +10,7 @@ from typing import Any, final

from flask import Flask, current_app, g

from dify_graph.context import register_context_capturer
from dify_graph.context.execution_context import (
    AppContext,
    IExecutionContext,
)
from context.execution_context import AppContext, IExecutionContext, register_context_capturer


@final

@@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar

if TYPE_CHECKING:
    from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
    from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
    from core.tools.plugin_tool.provider import PluginToolProviderController
    from core.trigger.provider import PluginTriggerProviderController

@@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl

plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))

plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
    ContextVar("plugin_model_providers")
)

plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
    ContextVar("plugin_model_providers_lock")
)

datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
    RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)
@@ -88,6 +88,7 @@ class ModelConfigResource(Resource):
    tenant_id=current_tenant_id,
    app_id=app_model.id,
    agent_tool=agent_tool_entity,
    user_id=current_user.id,
)
manager = ToolParameterConfigurationManager(
    tenant_id=current_tenant_id,
@@ -127,6 +128,7 @@ class ModelConfigResource(Resource):
        tenant_id=current_tenant_id,
        app_id=app_model.id,
        agent_tool=agent_tool_entity,
        user_id=current_user.id,
    )
except Exception:
    continue

@@ -20,6 +20,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.file_access import DatabaseFileAccessController
from core.helper.trace_id_helper import get_external_trace_id
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
@@ -51,6 +52,7 @@ from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
@@ -204,6 +206,7 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
        mappings=files,
        tenant_id=workflow.tenant_id,
        config=file_extra_config,
        access_controller=_file_access_controller,
    )
    return file_objs
@@ -15,7 +15,8 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.file import helpers as file_helpers
from dify_graph.variables.segment_group import SegmentGroup
from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment
@@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


@@ -389,13 +391,21 @@ class VariableApi(Resource):
if variable.value_type == SegmentType.FILE:
    if not isinstance(raw_value, dict):
        raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
    raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
    raw_value = build_from_mapping(
        mapping=raw_value,
        tenant_id=app_model.tenant_id,
        access_controller=_file_access_controller,
    )
elif variable.value_type == SegmentType.ARRAY_FILE:
    if not isinstance(raw_value, list):
        raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
    if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
        raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
    raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
    raw_value = build_from_mappings(
        mappings=raw_value,
        tenant_id=app_model.tenant_id,
        access_controller=_file_access_controller,
    )
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
@@ -12,6 +12,7 @@ from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError
from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from extensions.ext_database import db
@@ -496,6 +497,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource):

pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
form_tokens_by_form_id = _load_form_tokens_by_form_id(
    [reason.form_id for reason in pause_reasons if isinstance(reason, HumanInputRequired)]
)

# Build response
paused_at = pause_entity.paused_at if pause_entity else None
@@ -514,7 +518,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
        "pause_type": {
            "type": "human_input",
            "form_id": reason.form_id,
            "backstage_input_url": _build_backstage_input_url(reason.form_token),
            "backstage_input_url": _build_backstage_input_url(
                form_tokens_by_form_id.get(reason.form_id)
            ),
        },
    }
)
@@ -25,7 +25,7 @@ from controllers.console.wraps import (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
@@ -332,7 +332,7 @@ class DatasetListApi(Resource):
)

# check embedding setting
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -446,7 +446,7 @@ class DatasetApi(Resource):
data.update({"partial_member_list": part_users_list})

# check embedding setting
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

@@ -454,7 +454,7 @@ class DatasetInitApi(Resource):
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
    raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
    model_manager = ModelManager()
    model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
    model_manager.get_model_instance(
        tenant_id=current_tenant_id,
        provider=knowledge_config.embedding_model_provider,
@@ -283,7 +283,7 @@ class DatasetDocumentSegmentApi(Resource):
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
    # check embedding model setting
    try:
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
        model_manager.get_model_instance(
            tenant_id=current_tenant_id,
            provider=dataset.embedding_model_provider,
@@ -336,7 +336,7 @@ class DatasetDocumentSegmentAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
    try:
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
        model_manager.get_model_instance(
            tenant_id=current_tenant_id,
            provider=dataset.embedding_model_provider,
@@ -387,7 +387,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
    # check embedding model setting
    try:
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
        model_manager.get_model_instance(
            tenant_id=current_tenant_id,
            provider=dataset.embedding_model_provider,
@@ -572,7 +572,7 @@ class ChildChunkAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
    try:
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
        model_manager.get_model_instance(
            tenant_id=current_tenant_id,
            provider=dataset.embedding_model_provider,
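The same mechanical change repeats across these hunks: the bare `ModelManager()` constructor is replaced by a tenant-bound factory. A sketch of the new call-site shape; the import path and the trailing keyword arguments (which the hunks truncate after `provider=`) are assumptions:

```python
from core.model_manager import ModelManager  # assumed location
from dify_graph.model_runtime.entities.model_entities import ModelType


def check_embedding_model(current_tenant_id: str, dataset) -> None:
    # Old: ModelManager() was process-wide and relied on per-call tenant args.
    # New: the manager is bound to one workspace at construction time.
    model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
    model_manager.get_model_instance(
        tenant_id=current_tenant_id,
        provider=dataset.embedding_model_provider,
        model_type=ModelType.TEXT_EMBEDDING,  # assumed: truncated in the hunks
        model=dataset.embedding_model,  # assumed: truncated in the hunks
    )
```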
@@ -21,7 +21,8 @@ from controllers.console.app.workflow_draft_variable import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
@@ -33,6 +34,7 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()


def _create_pagination_parser():
@@ -223,13 +225,21 @@ class RagPipelineVariableApi(Resource):
if variable.value_type == SegmentType.FILE:
    if not isinstance(raw_value, dict):
        raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
    raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
    raw_value = build_from_mapping(
        mapping=raw_value,
        tenant_id=pipeline.tenant_id,
        access_controller=_file_access_controller,
    )
elif variable.value_type == SegmentType.ARRAY_FILE:
    if not isinstance(raw_value, list):
        raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
    if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
        raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
    raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
    raw_value = build_from_mappings(
        mappings=raw_value,
        tenant_id=pipeline.tenant_id,
        access_controller=_file_access_controller,
    )
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
@@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource):
)

if args.config_from == "predefined-model":
    available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
        tenant_id=tenant_id, provider_name=provider
    available_credentials = model_provider_service.get_provider_available_credentials(
        tenant_id=tenant_id,
        provider=provider,
    )
else:
    # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
    normalized_model_type = args.model_type.to_origin_model_type()
    available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
        tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
    available_credentials = model_provider_service.get_provider_model_available_credentials(
        tenant_id=tenant_id,
        provider=provider,
        model_type=normalized_model_type,
        model=args.model,
    )

return jsonable_encoder(
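This hunk stops the controller from reaching through `model_provider_service.provider_manager` and routes both lookups through service-level methods instead. A sketch of the facade those calls imply; only the method signatures appear in the diff, so the delegating bodies are assumptions:

```python
class ModelProviderService:  # sketch of the two methods the hunk calls
    def get_provider_available_credentials(self, *, tenant_id: str, provider: str):
        # assumed delegation to the manager the controller used to touch directly
        return self.provider_manager.get_provider_available_credentials(
            tenant_id=tenant_id, provider_name=provider
        )

    def get_provider_model_available_credentials(
        self, *, tenant_id: str, provider: str, model_type: str, model: str
    ):
        return self.provider_manager.get_provider_model_available_credentials(
            tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=model
        )
```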
@@ -70,22 +70,25 @@ class ToolFileApi(Resource):
except Exception:
    raise UnsupportedFileTypeError()

mime_type = tool_file.mime_type
filename = tool_file.filename

response = Response(
    stream,
    mimetype=tool_file.mimetype,
    mimetype=mime_type,
    direct_passthrough=True,
    headers={},
)
if tool_file.size > 0:
    response.headers["Content-Length"] = str(tool_file.size)
if args.as_attachment:
    encoded_filename = quote(tool_file.name)
if args.as_attachment and filename:
    encoded_filename = quote(filename)
    response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"

enforce_download_for_html(
    response,
    mime_type=tool_file.mimetype,
    filename=tool_file.name,
    mime_type=mime_type,
    filename=filename,
    extension=extension,
)
@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden

import services
from core.tools.signature import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file.helpers import verify_plugin_file_signature
from fields.file_fields import FileResponse

from ..common.errors import (

@@ -28,7 +28,7 @@ from core.plugin.entities.request import (
    RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.file.helpers import get_signed_file_url_for_plugin
from core.tools.signature import get_signed_file_url_for_plugin
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import length_prefixed_response
from models import Account, Tenant
@@ -14,7 +14,7 @@ from controllers.service_api.wraps import (
    DatasetApiResource,
    cloud_edition_billing_rate_limit_check,
)
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from fields.dataset_fields import dataset_detail_fields
@@ -140,10 +140,10 @@ class DatasetListApi(DatasetApiResource):
    query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
)
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -259,10 +259,10 @@ class DatasetApi(DatasetApiResource):
    raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

@@ -106,7 +106,7 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
    try:
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
        model_manager.get_model_instance(
            tenant_id=current_tenant_id,
            provider=dataset.embedding_model_provider,
@@ -160,7 +160,7 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
    try:
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
        model_manager.get_model_instance(
            tenant_id=current_tenant_id,
            provider=dataset.embedding_model_provider,
@@ -266,7 +266,7 @@ class DatasetSegmentApi(DatasetApiResource):
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
    # check embedding model setting
    try:
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
        model_manager.get_model_instance(
            tenant_id=current_tenant_id,
            provider=dataset.embedding_model_provider,
@@ -361,7 +361,7 @@ class ChildChunkApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
    try:
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
        model_manager.get_model_instance(
            tenant_id=current_tenant_id,
            provider=dataset.embedding_model_provider,
@@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.app.file_access import DatabaseFileAccessController
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -46,6 +47,7 @@ from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()


class BaseAgentRunner(AppRunner):
@@ -138,6 +140,7 @@ class BaseAgentRunner(AppRunner):
    tenant_id=self.tenant_id,
    app_id=self.app_config.app_id,
    agent_tool=tool,
    user_id=self.user_id,
    invoke_from=self.application_generate_entity.invoke_from,
)
assert tool_entity.entity.description
@@ -524,7 +527,10 @@ class BaseAgentRunner(AppRunner):
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

file_objs = file_factory.build_from_message_files(
    message_files=files, tenant_id=self.tenant_id, config=file_extra_config
    message_files=files,
    tenant_id=self.tenant_id,
    config=file_extra_config,
    access_controller=_file_access_controller,
)
if not file_objs:
    return UserPromptMessage(content=message.query)
@@ -122,7 +122,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
    tools=[],
    stop=app_generate_entity.model_conf.stop,
    stream=True,
    user=self.user_id,
    callbacks=[],
)


@@ -96,7 +96,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
    tools=prompt_messages_tools,
    stop=app_generate_entity.model_conf.stop,
    stream=self.stream_tool_call,
    user=self.user_id,
    callbacks=[],
)


@@ -4,7 +4,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -21,7 +21,7 @@ class ModelConfigConverter:
    """
    model_config = app_config.model

    provider_manager = ProviderManager()
    provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id)
    provider_model_bundle = provider_manager.get_provider_model_bundle(
        tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
    )
@@ -2,9 +2,8 @@ from collections.abc import Mapping
from typing import Any

from core.app.app_config.entities import ModelConfigEntity
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.model import AppModelConfigDict
from models.provider_ids import ModelProviderID

@@ -54,9 +53,12 @@ class ModelConfigManager:
if not isinstance(config["model"], dict):
    raise ValueError("model must be of object type")

# Keep provider discovery and provider-backed model listing on the same
# request-scoped runtime so caller scope and provider caches stay aligned.
assembly = create_plugin_model_assembly(tenant_id=tenant_id)

# model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
provider_entities = model_provider_factory.get_providers()
provider_entities = assembly.model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"]:
    raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
@@ -71,8 +73,7 @@ class ModelConfigManager:
if "name" not in config["model"]:
    raise ValueError("model.name is required")

provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
models = assembly.provider_manager.get_configurations(tenant_id).get_models(
    provider=config["model"]["provider"], model_type=ModelType.LLM
)
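`create_plugin_model_assembly` replaces two independently constructed objects (`ModelProviderFactory(tenant_id)` and `ProviderManager()`) with one container, so provider discovery and model listing share the same tenant-scoped runtime. A sketch of the shape the hunk implies; the container class itself is assumed, only its two attributes appear in the diff:

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class PluginModelAssembly:  # hypothetical name for the object the factory returns
    model_provider_factory: Any  # exposes get_providers()
    provider_manager: Any  # exposes get_configurations(tenant_id)


# Usage mirrored from the hunk above:
#   assembly = create_plugin_model_assembly(tenant_id=tenant_id)
#   provider_entities = assembly.model_provider_factory.get_providers()
#   models = assembly.provider_manager.get_configurations(tenant_id).get_models(...)
```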
@@ -24,6 +24,7 @@ from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
@@ -34,13 +35,9 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from dify_graph.repositories.draft_variable_repository import (
    DraftVariableSaverFactory,
)
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.runtime import GraphRuntimeState
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
@@ -150,85 +147,87 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
    file_objs = file_factory.build_from_mappings(
        mappings=files,
        tenant_id=app_model.tenant_id,
        config=file_extra_config,
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
    files = args["files"] if args.get("files") else []
    file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
    if file_extra_config:
        file_objs = file_factory.build_from_mappings(
            mappings=files,
            tenant_id=app_model.tenant_id,
            config=file_extra_config,
            access_controller=self._file_access_controller,
        )
    else:
        file_objs = []

    # convert to app config
    app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

    # get tracing instance
    trace_manager = TraceQueueManager(
        app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
    )
else:
    file_objs = []

# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
    if invoke_from == InvokeFrom.DEBUGGER:
        # always enable retriever resource in debugger mode
        app_config.additional_features.show_retrieve_source = True  # type: ignore

# get tracing instance
trace_manager = TraceQueueManager(
    app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
    # init application generate entity
    application_generate_entity = AdvancedChatAppGenerateEntity(
        task_id=str(uuid.uuid4()),
        app_config=app_config,
        file_upload_config=file_extra_config,
        conversation_id=conversation.id if conversation else None,
        inputs=self._prepare_user_inputs(
            user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
        ),
        query=query,
        files=list(file_objs),
        parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
        user_id=user.id,
        stream=streaming,
        invoke_from=invoke_from,
        extras=extras,
        trace_manager=trace_manager,
        workflow_run_id=str(workflow_run_id),
    )
    contexts.plugin_tool_providers.set({})
    contexts.plugin_tool_providers_lock.set(threading.Lock())

if invoke_from == InvokeFrom.DEBUGGER:
    # always enable retriever resource in debugger mode
    app_config.additional_features.show_retrieve_source = True  # type: ignore
    # Create repositories
    #
    # Create session factory
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    # Create workflow execution(aka workflow run) repository
    if invoke_from == InvokeFrom.DEBUGGER:
        workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
    else:
        workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
    workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=application_generate_entity.app_config.app_id,
        triggered_from=workflow_triggered_from,
    )
    # Create workflow node execution repository
    workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=application_generate_entity.app_config.app_id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
    task_id=str(uuid.uuid4()),
    app_config=app_config,
    file_upload_config=file_extra_config,
    conversation_id=conversation.id if conversation else None,
    inputs=self._prepare_user_inputs(
        user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
    ),
    query=query,
    files=list(file_objs),
    parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
    user_id=user.id,
    stream=streaming,
    invoke_from=invoke_from,
    extras=extras,
    trace_manager=trace_manager,
    workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
    workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
    workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
    session_factory=session_factory,
    user=user,
    app_id=application_generate_entity.app_config.app_id,
    triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
    session_factory=session_factory,
    user=user,
    app_id=application_generate_entity.app_config.app_id,
    triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)

return self._generate(
    workflow=workflow,
    user=user,
    invoke_from=invoke_from,
    application_generate_entity=application_generate_entity,
    workflow_execution_repository=workflow_execution_repository,
    workflow_node_execution_repository=workflow_node_execution_repository,
    conversation=conversation,
    stream=streaming,
    pause_state_config=pause_state_config,
)
    return self._generate(
        workflow=workflow,
        user=user,
        invoke_from=invoke_from,
        application_generate_entity=application_generate_entity,
        workflow_execution_repository=workflow_execution_repository,
        workflow_node_execution_repository=workflow_node_execution_repository,
        conversation=conversation,
        stream=streaming,
        pause_state_config=pause_state_config,
    )

def resume(
    self,
@@ -460,94 +459,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = conversation is None
with self._bind_file_access_scope(
    tenant_id=application_generate_entity.app_config.tenant_id,
    user=user,
    invoke_from=invoke_from,
):
    is_first_conversation = conversation is None

if conversation is not None and message is not None:
    pass
else:
    conversation, message = self._init_generate_records(application_generate_entity, conversation)
    if conversation is not None and message is not None:
        pass
    else:
        conversation, message = self._init_generate_records(application_generate_entity, conversation)

if is_first_conversation:
    # update conversation features
    conversation.override_model_configs = workflow.features
    db.session.commit()
    db.session.refresh(conversation)
    if is_first_conversation:
        # update conversation features
        conversation.override_model_configs = workflow.features
        db.session.commit()
        db.session.refresh(conversation)

# get conversation dialogue count
# NOTE: dialogue_count should not start from 0,
# because during the first conversation, dialogue_count should be 1.
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
    # get conversation dialogue count
    # NOTE: dialogue_count should not start from 0,
    # because during the first conversation, dialogue_count should be 1.
    self._dialogue_count = get_thread_messages_length(conversation.id) + 1

# init queue manager
queue_manager = MessageBasedAppQueueManager(
    task_id=application_generate_entity.task_id,
    user_id=application_generate_entity.user_id,
    invoke_from=application_generate_entity.invoke_from,
    conversation_id=conversation.id,
    app_mode=conversation.mode,
    message_id=message.id,
)

graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
    graph_layers.append(
        PauseStatePersistenceLayer(
            session_factory=pause_state_config.session_factory,
            generate_entity=application_generate_entity,
            state_owner_user_id=pause_state_config.state_owner_user_id,
        )
    # init queue manager
    queue_manager = MessageBasedAppQueueManager(
        task_id=application_generate_entity.task_id,
        user_id=application_generate_entity.user_id,
        invoke_from=application_generate_entity.invoke_from,
        conversation_id=conversation.id,
        app_mode=conversation.mode,
        message_id=message.id,
    )

# new thread with request context and contextvars
context = contextvars.copy_context()
    graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
    if pause_state_config is not None:
        graph_layers.append(
            PauseStatePersistenceLayer(
                session_factory=pause_state_config.session_factory,
                generate_entity=application_generate_entity,
                state_owner_user_id=pause_state_config.state_owner_user_id,
            )
        )

worker_thread = threading.Thread(
    target=self._generate_worker,
    kwargs={
        "flask_app": current_app._get_current_object(),  # type: ignore
        "application_generate_entity": application_generate_entity,
        "queue_manager": queue_manager,
        "conversation_id": conversation.id,
        "message_id": message.id,
        "context": context,
        "variable_loader": variable_loader,
        "workflow_execution_repository": workflow_execution_repository,
        "workflow_node_execution_repository": workflow_node_execution_repository,
        "graph_engine_layers": tuple(graph_layers),
        "graph_runtime_state": graph_runtime_state,
    },
)
    # new thread with request context and contextvars
    context = contextvars.copy_context()

worker_thread.start()
    worker_thread = threading.Thread(
        target=self._generate_worker,
        kwargs={
            "flask_app": current_app._get_current_object(),  # type: ignore
            "application_generate_entity": application_generate_entity,
            "queue_manager": queue_manager,
            "conversation_id": conversation.id,
            "message_id": message.id,
            "context": context,
            "variable_loader": variable_loader,
            "workflow_execution_repository": workflow_execution_repository,
            "workflow_node_execution_repository": workflow_node_execution_repository,
            "graph_engine_layers": tuple(graph_layers),
            "graph_runtime_state": graph_runtime_state,
        },
    )

# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
    workflow = _refresh_model(session, workflow)
    message = _refresh_model(session, message)
    # workflow_ = session.get(Workflow, workflow.id)
    # assert workflow_ is not None
    # workflow = workflow_
    # message_ = session.get(Message, message.id)
    # assert message_ is not None
    # message = message_
    # db.session.refresh(workflow)
    # db.session.refresh(message)
    # db.session.refresh(user)
db.session.close()
    worker_thread.start()

# return response or stream generator
response = self._handle_advanced_chat_response(
    application_generate_entity=application_generate_entity,
    workflow=workflow,
    queue_manager=queue_manager,
    conversation=conversation,
    message=message,
    user=user,
    stream=stream,
    draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
    # release database connection, because the following new thread operations may take a long time
    with Session(bind=db.engine, expire_on_commit=False) as session:
        workflow = _refresh_model(session, workflow)
        message = _refresh_model(session, message)
    db.session.close()

return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
    # return response or stream generator
    response = self._handle_advanced_chat_response(
        application_generate_entity=application_generate_entity,
        workflow=workflow,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message,
        user=user,
        stream=stream,
        draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
    )

    return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

def _generate_worker(
    self,
|
||||
|
||||
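A note on the threading pattern above: Python threads do not inherit contextvars automatically, so the generator snapshots the caller's context with contextvars.copy_context() and hands it to the worker, which replays it via context.run. A minimal, self-contained sketch of that mechanism (illustrative only, not part of this commit):

import contextvars
import threading

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")

def worker(context: contextvars.Context) -> None:
    # Running the lookup through the snapshot makes the caller's value visible here.
    print(context.run(request_id.get))

request_id.set("req-42")
ctx = contextvars.copy_context()
thread = threading.Thread(target=worker, args=(ctx,))
thread.start()
thread.join()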
@@ -25,14 +25,19 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import get_default_root_node_id
from core.workflow.system_variables import (
    build_bootstrap_variables,
    build_system_variables,
    system_variables_to_mapping,
)
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.enums import WorkflowType
from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import VariableLoader
from dify_graph.variables.variables import Variable
from extensions.ext_database import db
@@ -90,7 +95,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
        app_config = self.application_generate_entity.app_config
        app_config = cast(AdvancedChatAppConfig, app_config)

        system_inputs = SystemVariable(
        system_inputs = build_system_variables(
            query=self.application_generate_entity.query,
            files=self.application_generate_entity.files,
            conversation_id=self.conversation.id,
@@ -150,7 +155,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

            self.application_generate_entity.inputs = new_inputs
            self.application_generate_entity.query = new_query
            system_inputs.query = new_query
            system_inputs = build_system_variables(
                system_variables_to_mapping(system_inputs),
                query=new_query,
            )

        # annotation reply
        if self.handle_annotation_reply(
@@ -166,14 +174,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

        # Create a variable pool.
        # init variable pool
        variable_pool = VariablePool(
            system_variables=system_inputs,
            user_inputs=new_inputs,
            environment_variables=self._workflow.environment_variables,
            # Based on the definition of `Variable`,
            # `VariableBase` instances can be safely used as `Variable` since they are compatible.
            conversation_variables=conversation_variables,
        variable_pool = VariablePool()
        add_variables_to_pool(
            variable_pool,
            build_bootstrap_variables(
                system_variables=system_inputs,
                environment_variables=self._workflow.environment_variables,
                conversation_variables=conversation_variables,
            ),
        )
        root_node_id = get_default_root_node_id(self._workflow.graph_dict)
        add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=new_inputs)

        # init graph
        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
@@ -185,6 +196,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            user_id=self.application_generate_entity.user_id,
            user_from=user_from,
            invoke_from=invoke_from,
            root_node_id=root_node_id,
        )

        db.session.close()
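The runner hunks above replace the monolithic VariablePool constructor with an explicit bootstrap: create an empty pool, add the bootstrap variables, then attach user inputs to the resolved root node. A simplified sketch of that flow with hypothetical stand-in types (the real add_variables_to_pool and add_node_inputs_to_pool live in core.workflow.variable_pool_initializer and have richer signatures):

from dataclasses import dataclass, field

@dataclass
class Pool:
    # Flat variable store plus per-node input buckets, standing in for VariablePool.
    variables: dict[str, object] = field(default_factory=dict)
    node_inputs: dict[str, dict] = field(default_factory=dict)

def add_variables(pool: Pool, variables: dict[str, object]) -> None:
    pool.variables.update(variables)

def add_node_inputs(pool: Pool, *, node_id: str, inputs: dict) -> None:
    pool.node_inputs[node_id] = dict(inputs)

pool = Pool()
add_variables(pool, {"sys.query": "hello", "env.API_BASE": "https://example.test"})
add_node_inputs(pool, node_id="start", inputs={"name": "alice"})
assert pool.node_inputs["start"]["name"] == "alice"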
@@ -14,6 +14,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.entities.app_invoke_entities import (
    AdvancedChatAppGenerateEntity,
    InvokeFrom,
@@ -65,14 +66,14 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import build_system_variables
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.runtime import GraphRuntimeState
from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
@@ -117,7 +118,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        else:
            raise NotImplementedError(f"User type not supported: {type(user)}")

        self._workflow_system_variables = SystemVariable(
        self._workflow_system_variables = build_system_variables(
            query=message.query,
            files=application_generate_entity.files,
            conversation_id=conversation.id,
@@ -741,8 +742,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
    def _load_human_input_form_id(self, *, node_id: str) -> str | None:
        form_repository = HumanInputFormRepositoryImpl(
            tenant_id=self._workflow_tenant_id,
            workflow_execution_id=self._workflow_run_id,
        )
        form = form_repository.get_form(self._workflow_run_id, node_id)
        form = form_repository.get_form(node_id)
        if form is None:
            return None
        return form.id
@@ -933,21 +935,23 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

        metadata = self._task_state.metadata.model_dump()
        message.message_metadata = json.dumps(jsonable_encoder(metadata))
        message_files = [
            MessageFile(
                message_id=message.id,
                type=file["type"],
                transfer_method=file["transfer_method"],
                url=file["remote_url"],
                belongs_to=MessageFileBelongsTo.ASSISTANT,
                upload_file_id=file["related_id"],
                created_by_role=CreatorUserRole.ACCOUNT
                if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                else CreatorUserRole.END_USER,
                created_by=message.from_account_id or message.from_end_user_id or "",
            )
            for file in self._recorded_files
        ]
        message_files: list[MessageFile] = []
        for file in self._recorded_files:
            reference = file.get("reference") or file.get("related_id")
            message_files.append(
                MessageFile(
                    message_id=message.id,
                    type=file["type"],
                    transfer_method=file["transfer_method"],
                    url=file["remote_url"],
                    belongs_to=MessageFileBelongsTo.ASSISTANT,
                    upload_file_id=resolve_file_record_id(reference if isinstance(reference, str) else None),
                    created_by_role=CreatorUserRole.ACCOUNT
                    if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                    else CreatorUserRole.END_USER,
                    created_by=message.from_account_id or message.from_end_user_id or "",
                )
            )
        session.add_all(message_files)

    def _seed_graph_runtime_state_from_queue_manager(self) -> None:
@@ -1003,13 +1007,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        return message

    def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
        with Session(db.engine) as session, session.begin():
            saver = self._draft_var_saver_factory(
                session=session,
                app_id=self._application_generate_entity.app_config.app_id,
                node_id=event.node_id,
                node_type=event.node_type,
                node_execution_id=node_execution_id,
                enclosing_node_id=event.in_loop_id or event.in_iteration_id,
            )
            saver.save(event.process_data, event.outputs)
        saver = self._draft_var_saver_factory(
            app_id=self._application_generate_entity.app_config.app_id,
            node_id=event.node_id,
            node_type=event.node_type,
            node_execution_id=node_execution_id,
            enclosing_node_id=event.in_loop_id or event.in_iteration_id,
        )
        saver.save(event.process_data, event.outputs)
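With the change above, the pipeline no longer opens a SQLAlchemy session around the factory call; whatever saver the factory returns now owns its own unit of work. A sketch of a session-owning saver behind the same port (hypothetical names, illustrative only; the persist call is a placeholder, not a real API):

from collections.abc import Mapping
from typing import Any, Callable, Protocol

class DraftSaver(Protocol):
    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: ...

class SessionOwningSaver:
    """Opens its own unit of work instead of borrowing the caller's session."""

    def __init__(self, unit_of_work: Callable[[], Any]) -> None:
        # unit_of_work would be something like a sessionmaker bound to an engine.
        self._unit_of_work = unit_of_work

    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
        with self._unit_of_work() as session:
            session.persist(process_data, outputs)  # hypothetical persistence call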
@@ -129,89 +129,93 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        files = args.get("files") or []
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:
            file_objs = file_factory.build_from_mappings(
                mappings=files,
                tenant_id=app_model.tenant_id,
                config=file_extra_config,
            )
        else:
            file_objs = []
        with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
            files = args.get("files") or []
            file_extra_config = FileUploadConfigManager.convert(
                override_model_config_dict or app_model_config.to_dict()
            )
            if file_extra_config:
                file_objs = file_factory.build_from_mappings(
                    mappings=files,
                    tenant_id=app_model.tenant_id,
                    config=file_extra_config,
                    access_controller=self._file_access_controller,
                )
            else:
                file_objs = []

            # convert to app config
            app_config = AgentChatAppConfigManager.get_app_config(
                app_model=app_model,
                app_model_config=app_model_config,
                conversation=conversation,
                override_config_dict=override_model_config_dict,
            )

            # get tracing instance
            trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)

            # init application generate entity
            application_generate_entity = AgentChatAppGenerateEntity(
                task_id=str(uuid.uuid4()),
                app_config=app_config,
                model_conf=ModelConfigConverter.convert(app_config),
                file_upload_config=file_extra_config,
                conversation_id=conversation.id if conversation else None,
                inputs=self._prepare_user_inputs(
                    user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
                ),
                query=query,
                files=list(file_objs),
                parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
                user_id=user.id,
                stream=streaming,
                invoke_from=invoke_from,
                extras=extras,
                call_depth=0,
                trace_manager=trace_manager,
            )

            # init generate records
            (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

            # init queue manager
            queue_manager = MessageBasedAppQueueManager(
                task_id=application_generate_entity.task_id,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                conversation_id=conversation.id,
                app_mode=conversation.mode,
                message_id=message.id,
            )

            # new thread with request context and contextvars
            context = contextvars.copy_context()

            worker_thread = threading.Thread(
                target=self._generate_worker,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "context": context,
                    "application_generate_entity": application_generate_entity,
                    "queue_manager": queue_manager,
                    "conversation_id": conversation.id,
                    "message_id": message.id,
                },
            )

            worker_thread.start()

            # return response or stream generator
            response = self._handle_response(
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation=conversation,
                message=message,
                user=user,
                stream=streaming,
            )
            return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def _generate_worker(
        self,
@@ -1,17 +1,20 @@
from collections.abc import Generator, Mapping, Sequence
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Union, final

from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.enums import NodeType
from dify_graph.file import File, FileUploadConfig
from dify_graph.repositories.draft_variable_repository import (
from core.app.apps.draft_variable_saver import (
    DraftVariableSaver,
    DraftVariableSaverFactory,
    NoopDraftVariableSaver,
)
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
from dify_graph.enums import NodeType
from dify_graph.file import File, FileUploadConfig
from dify_graph.variables.input_entities import VariableEntityType
from extensions.ext_database import db
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
@@ -21,7 +24,66 @@ if TYPE_CHECKING:
    from dify_graph.variables.input_entities import VariableEntity


@final
class _DebuggerDraftVariableSaver:
    """Adapter that binds SQLAlchemy session setup outside the saver port."""

    def __init__(
        self,
        *,
        account: Account,
        app_id: str,
        node_id: str,
        node_type: NodeType,
        node_execution_id: str,
        enclosing_node_id: str | None = None,
    ) -> None:
        self._account = account
        self._app_id = app_id
        self._node_id = node_id
        self._node_type = node_type
        self._node_execution_id = node_execution_id
        self._enclosing_node_id = enclosing_node_id

    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
        with Session(db.engine) as session, session.begin():
            DraftVariableSaverImpl(
                session=session,
                app_id=self._app_id,
                node_id=self._node_id,
                node_type=self._node_type,
                node_execution_id=self._node_execution_id,
                enclosing_node_id=self._enclosing_node_id,
                user=self._account,
            ).save(process_data, outputs)


class BaseAppGenerator:
    _file_access_controller: DatabaseFileAccessController = DatabaseFileAccessController()

    @staticmethod
    def _bind_file_access_scope(
        *,
        tenant_id: str,
        user: Account | EndUser,
        invoke_from: InvokeFrom,
    ) -> AbstractContextManager[None]:
        """Bind request-scoped file ownership markers for downstream file lookups."""

        user_id = getattr(user, "id", None)
        if not isinstance(user_id, str) or not user_id:
            return nullcontext()

        user_from = UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER
        return bind_file_access_scope(
            FileAccessScope(
                tenant_id=tenant_id,
                user_id=user_id,
                user_from=user_from,
                invoke_from=invoke_from,
            )
        )

    def _prepare_user_inputs(
        self,
        *,
@@ -50,6 +112,7 @@ class BaseAppGenerator:
                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
                ),
                strict_type_validation=strict_type_validation,
                access_controller=self._file_access_controller,
            )
            for k, v in user_inputs.items()
            if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
@@ -64,6 +127,7 @@ class BaseAppGenerator:
                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
                ),
                access_controller=self._file_access_controller,
            )
            for k, v in user_inputs.items()
            if isinstance(v, list)
@@ -226,32 +290,30 @@ class BaseAppGenerator:
            assert isinstance(account, Account)

            def draft_var_saver_factory(
                session: Session,
                app_id: str,
                node_id: str,
                node_type: NodeType,
                node_execution_id: str,
                enclosing_node_id: str | None = None,
            ) -> DraftVariableSaver:
                return DraftVariableSaverImpl(
                    session=session,
                return _DebuggerDraftVariableSaver(
                    account=account,
                    app_id=app_id,
                    node_id=node_id,
                    node_type=node_type,
                    node_execution_id=node_execution_id,
                    enclosing_node_id=enclosing_node_id,
                    user=account,
                )
        else:

            def draft_var_saver_factory(
                session: Session,
                app_id: str,
                node_id: str,
                node_type: NodeType,
                node_execution_id: str,
                enclosing_node_id: str | None = None,
            ) -> DraftVariableSaver:
                _ = app_id, node_id, node_type, node_execution_id, enclosing_node_id
                return NoopDraftVariableSaver()

        return draft_var_saver_factory
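_bind_file_access_scope returns a context manager, so the ownership markers are visible to any file lookup on the same logical call path and are reset on exit. A minimal sketch of that pattern using contextvars (hypothetical Scope type, not the real FileAccessScope):

import contextlib
import contextvars
from dataclasses import dataclass

@dataclass(frozen=True)
class Scope:
    tenant_id: str
    user_id: str

_scope: contextvars.ContextVar[Scope | None] = contextvars.ContextVar("file_scope", default=None)

@contextlib.contextmanager
def bind_scope(scope: Scope):
    # Set the marker for the duration of the block, restoring the prior value after.
    token = _scope.set(scope)
    try:
        yield
    finally:
        _scope.reset(token)

with bind_scope(Scope(tenant_id="t1", user_id="u1")):
    current = _scope.get()
    assert current is not None and current.tenant_id == "t1"
assert _scope.get() is None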
@@ -61,27 +61,30 @@ class AppQueueManager(ABC):
        listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
        start_time = time.time()
        last_ping_time: int | float = 0
        while True:
            try:
                message = self._q.get(timeout=1)
                if message is None:
                    break
        try:
            while True:
                try:
                    message = self._q.get(timeout=1)
                    if message is None:
                        break

                    yield message
                except queue.Empty:
                    continue
                finally:
                    elapsed_time = time.time() - start_time
                    if elapsed_time >= listen_timeout or self._is_stopped():
                        # publish two messages to make sure the client can receive the stop signal
                        # and stop listening after the stop signal processed
                        self.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
                        )

                    if elapsed_time // 10 > last_ping_time:
                        self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
                        last_ping_time = elapsed_time // 10
        finally:
            self._graph_runtime_state = None  # Release reference once consumers finish or close the generator.

    def stop_listen(self):
        """
@@ -90,7 +93,6 @@ class AppQueueManager(ABC):
        """
        self._clear_task_belong_cache()
        self._q.put(None)
        self._graph_runtime_state = None  # Release reference to allow GC to reclaim memory

    def _clear_task_belong_cache(self) -> None:
        """
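The refactored listen loop keeps its existing timeout and ping cadence but adds an outer try/finally, so the runtime-state reference is released even when the consumer abandons the generator early. A self-contained sketch of the same loop shape (illustrative, with simplified stand-ins for the queue events):

import queue
import time

def listen(q: "queue.Queue[str | None]", timeout_s: float = 30.0):
    start = time.time()
    last_ping = 0.0
    shared_state: dict | None = {"runtime": "state"}
    try:
        while True:
            try:
                item = q.get(timeout=1)
                if item is None:
                    break
                yield item
            except queue.Empty:
                continue
            finally:
                elapsed = time.time() - start
                if elapsed >= timeout_s:
                    q.put(None)  # enqueue a stop signal so the loop terminates
                if elapsed // 10 > last_ping:
                    last_ping = elapsed // 10  # a real manager publishes a ping event here
    finally:
        shared_state = None  # mirrors clearing _graph_runtime_state on generator close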
@@ -1,3 +1,4 @@
import contextvars
import logging
import threading
import uuid
@@ -120,89 +121,96 @@ class ChatAppGenerator(MessageBasedAppGenerator):
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        files = args["files"] if args.get("files") else []
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:
            file_objs = file_factory.build_from_mappings(
                mappings=files,
                tenant_id=app_model.tenant_id,
                config=file_extra_config,
            )
        else:
            file_objs = []
        with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
            files = args["files"] if args.get("files") else []
            file_extra_config = FileUploadConfigManager.convert(
                override_model_config_dict or app_model_config.to_dict()
            )
            if file_extra_config:
                file_objs = file_factory.build_from_mappings(
                    mappings=files,
                    tenant_id=app_model.tenant_id,
                    config=file_extra_config,
                    access_controller=self._file_access_controller,
                )
            else:
                file_objs = []

            # convert to app config
            app_config = ChatAppConfigManager.get_app_config(
                app_model=app_model,
                app_model_config=app_model_config,
                conversation=conversation,
                override_config_dict=override_model_config_dict,
            )

            # get tracing instance
            trace_manager = TraceQueueManager(
                app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
            )

            # init application generate entity
            application_generate_entity = ChatAppGenerateEntity(
                task_id=str(uuid.uuid4()),
                app_config=app_config,
                model_conf=ModelConfigConverter.convert(app_config),
                file_upload_config=file_extra_config,
                conversation_id=conversation.id if conversation else None,
                inputs=self._prepare_user_inputs(
                    user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
                ),
                query=query,
                files=list(file_objs),
                parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
                user_id=user.id,
                invoke_from=invoke_from,
                extras=extras,
                trace_manager=trace_manager,
                stream=streaming,
            )

            # init generate records
            (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

            # init queue manager
            queue_manager = MessageBasedAppQueueManager(
                task_id=application_generate_entity.task_id,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                conversation_id=conversation.id,
                app_mode=conversation.mode,
                message_id=message.id,
            )

            context = contextvars.copy_context()

            # new thread with request context
            @copy_current_request_context
            def worker_with_context():
                return self._generate_worker(
                return context.run(
                    self._generate_worker,
                    flask_app=current_app._get_current_object(),  # type: ignore
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    conversation_id=conversation.id,
                    message_id=message.id,
                )

            worker_thread = threading.Thread(target=worker_with_context)

            worker_thread.start()

            # return response or stream generator
            response = self._handle_response(
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation=conversation,
                message=message,
                user=user,
                stream=streaming,
            )

            return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def _generate_worker(
        self,
@@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner):
            model_parameters=application_generate_entity.model_conf.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
@@ -4,6 +4,7 @@ from __future__ import annotations

from typing import TYPE_CHECKING

from core.workflow.system_variables import SystemVariableKey, get_system_text
from dify_graph.runtime import GraphRuntimeState

if TYPE_CHECKING:
@@ -30,10 +31,10 @@ class GraphRuntimeStateSupport:
        return self._resolve_graph_runtime_state(graph_runtime_state)

    def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str:
        system_variables = graph_runtime_state.variable_pool.system_variables
        if not system_variables or not system_variables.workflow_execution_id:
        workflow_run_id = get_system_text(graph_runtime_state.variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID)
        if not workflow_run_id:
            raise ValueError("workflow_execution_id missing from runtime state")
        return str(system_variables.workflow_execution_id)
        return workflow_run_id

    def _resolve_graph_runtime_state(
        self,
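_extract_workflow_run_id now goes through get_system_text instead of dereferencing the SystemVariable object, which keeps the support mixin decoupled from the model class. A sketch of what such an accessor can look like over a plain mapping (hypothetical pool shape, not the real VariablePool API):

from collections.abc import Mapping

def get_system_text(pool: Mapping[str, object], key: str) -> str:
    # Return an empty string for missing values so callers can use simple truthiness checks.
    value = pool.get(f"sys.{key}")
    return "" if value is None else str(value)

pool = {"sys.workflow_execution_id": "run-123"}
run_id = get_system_text(pool, "workflow_execution_id")
if not run_id:
    raise ValueError("workflow_execution_id missing from runtime state")
assert run_id == "run-123"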
@@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Mapping, Sequence
@@ -50,20 +51,21 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import (
    BuiltinNodeTypes,
    SystemVariableKey,
    WorkflowExecutionStatus,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from dify_graph.file import FILE_MODEL_IDENTITY, File
from dify_graph.runtime import GraphRuntimeState
from dify_graph.system_variable import SystemVariable
from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment
from dify_graph.variables.variables import Variable
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -111,11 +113,11 @@ class WorkflowResponseConverter:
        *,
        application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
        user: Union[Account, EndUser],
        system_variables: SystemVariable,
        system_variables: Sequence[Variable],
    ):
        self._application_generate_entity = application_generate_entity
        self._user = user
        self._system_variables = system_variables
        self._system_variables = system_variables_to_mapping(system_variables)
        self._workflow_inputs = self._prepare_workflow_inputs()

        # Disable truncation for SERVICE_API calls to keep backward compatibility.
@@ -133,7 +135,7 @@ class WorkflowResponseConverter:
    # ------------------------------------------------------------------
    def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
        inputs = dict(self._application_generate_entity.inputs)
        for field_name, value in self._system_variables.to_dict().items():
        for field_name, value in self._system_variables.items():
            # TODO(@future-refactor): store system variables separately from user inputs so we don't
            # need to flatten `sys.*` entries into the input payload just for rerun/export tooling.
            if field_name == SystemVariableKey.CONVERSATION_ID:
@@ -318,13 +320,23 @@ class WorkflowResponseConverter:
        pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
        human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
        expiration_times_by_form_id: dict[str, datetime] = {}
        display_in_ui_by_form_id: dict[str, bool] = {}
        form_token_by_form_id: dict[str, str] = {}
        if human_input_form_ids:
            stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
                HumanInputForm.id.in_(human_input_form_ids)
            )
            stmt = select(
                HumanInputForm.id,
                HumanInputForm.expiration_time,
                HumanInputForm.form_definition,
            ).where(HumanInputForm.id.in_(human_input_form_ids))
            with Session(bind=db.engine) as session:
                for form_id, expiration_time in session.execute(stmt):
                for form_id, expiration_time, form_definition in session.execute(stmt):
                    expiration_times_by_form_id[str(form_id)] = expiration_time
                    try:
                        definition_payload = json.loads(form_definition) if form_definition else {}
                    except (TypeError, json.JSONDecodeError):
                        definition_payload = {}
                    display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
                form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)

        responses: list[StreamResponse] = []
@@ -344,8 +356,8 @@ class WorkflowResponseConverter:
                    form_content=reason.form_content,
                    inputs=reason.inputs,
                    actions=reason.actions,
                    display_in_ui=reason.display_in_ui,
                    form_token=reason.form_token,
                    display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False),
                    form_token=form_token_by_form_id.get(reason.form_id),
                    resolved_default_values=reason.resolved_default_values,
                    expiration_time=int(expiration_time.timestamp()),
                ),
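Since the converter now receives system variables as a sequence and normalizes them with system_variables_to_mapping, downstream code iterates a plain mapping instead of calling a model-specific to_dict(). A small sketch of that normalization (hypothetical Var type, illustrative only):

from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
    name: str
    value: object

def to_mapping(variables: list[Var]) -> dict[str, object]:
    # Last write wins if names repeat, matching ordinary dict-comprehension semantics.
    return {v.name: v.value for v in variables}

sys_vars = [Var("conversation_id", "c-1"), Var("query", "hi")]
mapping = to_mapping(sys_vars)
for field_name, value in mapping.items():
    pass  # a real converter flattens these `sys.*` entries into the exported inputs
assert mapping["query"] == "hi"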
@@ -1,3 +1,4 @@
import contextvars
import logging
import threading
import uuid
@@ -108,83 +109,90 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        files = args["files"] if args.get("files") else []
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:
            file_objs = file_factory.build_from_mappings(
                mappings=files,
                tenant_id=app_model.tenant_id,
                config=file_extra_config,
            )
        else:
            file_objs = []
        with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
            files = args["files"] if args.get("files") else []
            file_extra_config = FileUploadConfigManager.convert(
                override_model_config_dict or app_model_config.to_dict()
            )
            if file_extra_config:
                file_objs = file_factory.build_from_mappings(
                    mappings=files,
                    tenant_id=app_model.tenant_id,
                    config=file_extra_config,
                    access_controller=self._file_access_controller,
                )
            else:
                file_objs = []

            # convert to app config
            app_config = CompletionAppConfigManager.get_app_config(
                app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
            )

            # get tracing instance
            trace_manager = TraceQueueManager(
                app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
            )

            # init application generate entity
            application_generate_entity = CompletionAppGenerateEntity(
                task_id=str(uuid.uuid4()),
                app_config=app_config,
                model_conf=ModelConfigConverter.convert(app_config),
                file_upload_config=file_extra_config,
                inputs=self._prepare_user_inputs(
                    user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
                ),
                query=query,
                files=list(file_objs),
                user_id=user.id,
                stream=streaming,
                invoke_from=invoke_from,
                extras={},
                trace_manager=trace_manager,
            )

            # init generate records
            (conversation, message) = self._init_generate_records(application_generate_entity)

            # init queue manager
            queue_manager = MessageBasedAppQueueManager(
                task_id=application_generate_entity.task_id,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                conversation_id=conversation.id,
                app_mode=conversation.mode,
                message_id=message.id,
            )

            context = contextvars.copy_context()

            # new thread with request context
            @copy_current_request_context
            def worker_with_context():
                return self._generate_worker(
                return context.run(
                    self._generate_worker,
                    flask_app=current_app._get_current_object(),  # type: ignore
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    message_id=message.id,
                )

            worker_thread = threading.Thread(target=worker_with_context)

            worker_thread.start()

            # return response or stream generator
            response = self._handle_response(
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation=conversation,
                message=message,
                user=user,
                stream=streaming,
            )

            return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def _generate_worker(
        self,
@@ -280,71 +288,76 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
        model_dict["completion_params"] = completion_params
        override_model_config_dict["model"] = model_dict

        # parse files
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
        if file_extra_config:
            file_objs = file_factory.build_from_mappings(
                mappings=message.message_files,
                tenant_id=app_model.tenant_id,
                config=file_extra_config,
            )
        else:
            file_objs = []
        with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
            # parse files
            file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
            if file_extra_config:
                file_objs = file_factory.build_from_mappings(
                    mappings=message.message_files,
                    tenant_id=app_model.tenant_id,
                    config=file_extra_config,
                    access_controller=self._file_access_controller,
                )
            else:
                file_objs = []

            # convert to app config
            app_config = CompletionAppConfigManager.get_app_config(
                app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
            )

            # init application generate entity
            application_generate_entity = CompletionAppGenerateEntity(
                task_id=str(uuid.uuid4()),
                app_config=app_config,
                model_conf=ModelConfigConverter.convert(app_config),
                inputs=message.inputs,
                query=message.query,
                files=list(file_objs),
                user_id=user.id,
                stream=stream,
                invoke_from=invoke_from,
                extras={},
            )

            # init generate records
            (conversation, message) = self._init_generate_records(application_generate_entity)

            # init queue manager
            queue_manager = MessageBasedAppQueueManager(
                task_id=application_generate_entity.task_id,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                conversation_id=conversation.id,
                app_mode=conversation.mode,
                message_id=message.id,
            )

            context = contextvars.copy_context()

            # new thread with request context
            @copy_current_request_context
            def worker_with_context():
                return self._generate_worker(
                return context.run(
                    self._generate_worker,
                    flask_app=current_app._get_current_object(),  # type: ignore
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    message_id=message.id,
                )

            worker_thread = threading.Thread(target=worker_with_context)

            worker_thread.start()

            # return response or stream generator
            response = self._handle_response(
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation=conversation,
                message=message,
                user=user,
                stream=stream,
            )

            return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner):
            model_parameters=application_generate_entity.model_conf.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
@@ -4,31 +4,30 @@ import abc
from collections.abc import Mapping
from typing import Any, Protocol

from sqlalchemy.orm import Session

from dify_graph.enums import NodeType


class DraftVariableSaver(Protocol):
    @abc.abstractmethod
    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
        pass
    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
        """Persist node draft variables for a completed execution."""
        raise NotImplementedError


class DraftVariableSaverFactory(Protocol):
    @abc.abstractmethod
    def __call__(
        self,
        session: Session,
        app_id: str,
        node_id: str,
        node_type: NodeType,
        node_execution_id: str,
        enclosing_node_id: str | None = None,
    ) -> DraftVariableSaver:
        pass
        """Build a saver bound to a concrete node execution."""
        raise NotImplementedError


class NoopDraftVariableSaver(DraftVariableSaver):
    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
        return None
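The Protocol pair above is a port with a null-object implementation: call sites always receive a saver, and the no-op variant makes draft persistence opt-out without branching at every call site. A compact sketch of the same shape (illustrative names):

from collections.abc import Mapping
from typing import Any, Protocol

class Saver(Protocol):
    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: ...

class NoopSaver:
    # Structural typing: this satisfies Saver without inheriting from it.
    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
        return None

def run_node(saver: Saver) -> None:
    saver.save({"tokens": 3}, {"text": "ok"})  # same call path whether saving is real or a no-op

run_node(NoopSaver())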
@@ -28,6 +28,7 @@ from core.app.entities.task_entities (
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file_reference import resolve_file_record_id
from extensions.ext_database import db
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
@@ -227,7 +228,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
                transfer_method=file.transfer_method,
                belongs_to=MessageFileBelongsTo.USER,
                url=file.remote_url,
                upload_file_id=file.related_id,
                upload_file_id=resolve_file_record_id(file.reference),
                created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
                created_by=account_id or end_user_id or "",
            )
@@ -18,6 +18,7 @@ import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
@@ -34,11 +35,12 @@ from core.datasource.entities.datasource_entities import (
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.factory import (
    DifyCoreRepositoryFactory,
    WorkflowExecutionRepository,
    WorkflowNodeExecutionRepository,
)
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from libs.flask_utils import preserve_flask_contexts
@@ -12,16 +12,16 @@ from core.app.entities.app_invoke_entities import (
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.graph_init_params import GraphInitParams
|
||||
from dify_graph.enums import WorkflowType
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import VariableLoader
|
||||
from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from extensions.ext_database import db
|
||||
@@ -112,7 +112,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
        files = self.application_generate_entity.files

        # Create a variable pool.
        system_inputs = SystemVariable(
        system_inputs = build_system_variables(
            files=files,
            user_id=user_id,
            app_id=app_config.app_id,
@@ -142,19 +142,25 @@ class PipelineRunner(WorkflowBasedAppRunner):
            )
        )

        variable_pool = VariablePool(
            system_variables=system_inputs,
            user_inputs=inputs,
            environment_variables=workflow.environment_variables,
            conversation_variables=[],
            rag_pipeline_variables=rag_pipeline_variables,
        variable_pool = VariablePool()
        add_variables_to_pool(
            variable_pool,
            build_bootstrap_variables(
                system_variables=system_inputs,
                environment_variables=workflow.environment_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            ),
        )
        root_node_id = self.application_generate_entity.start_node_id or get_default_root_node_id(
            workflow.graph_dict
        )
        add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)
        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

        # init graph
        graph = self._init_rag_pipeline_graph(
            graph_runtime_state=graph_runtime_state,
            start_node_id=self.application_generate_entity.start_node_id,
            start_node_id=root_node_id,
            workflow=workflow,
            user_from=user_from,
            invoke_from=invoke_from,

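The hunk above replaces the monolithic VariablePool constructor with an explicit bootstrap sequence: create an empty pool, seed it via add_variables_to_pool with the output of build_bootstrap_variables, then attach the root node's inputs. A minimal sketch of that flow, assuming only the call signatures visible in this diff (the literal inputs are illustrative):

    # Sketch only; assumes the helper signatures shown in this diff.
    variable_pool = VariablePool()
    add_variables_to_pool(
        variable_pool,
        build_bootstrap_variables(
            system_variables=system_inputs,  # from build_system_variables(...)
            environment_variables=workflow.environment_variables,
        ),
    )
    root_node_id = get_default_root_node_id(workflow.graph_dict)
    add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs={"query": "hello"})  # illustrative inputs
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
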
@@ -17,6 +17,7 @@ from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
@@ -30,11 +31,9 @@ from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.runtime import GraphRuntimeState
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
@@ -129,107 +128,109 @@ class WorkflowAppGenerator(BaseAppGenerator):
        graph_engine_layers: Sequence[GraphEngineLayer] = (),
        pause_state_config: PauseStateLayerConfig | None = None,
    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
        files: Sequence[Mapping[str, Any]] = args.get("files") or []
        with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
            files: Sequence[Mapping[str, Any]] = args.get("files") or []

        # parse files
        # TODO(QuantumGhost): Move file parsing logic to the API controller layer
        # for better separation of concerns.
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
        system_files = file_factory.build_from_mappings(
            mappings=files,
            tenant_id=app_model.tenant_id,
            config=file_extra_config,
            strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
        )

        # convert to app config
        app_config = WorkflowAppConfigManager.get_app_config(
            app_model=app_model,
            workflow=workflow,
        )

        # get tracing instance
        trace_manager = TraceQueueManager(
            app_id=app_model.id,
            user_id=user.id if isinstance(user, Account) else user.session_id,
        )

        inputs: Mapping[str, Any] = args["inputs"]

        extras = {
            **extract_external_trace_id_from_args(args),
        }
        workflow_run_id = str(workflow_run_id or uuid.uuid4())
        # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
        # trigger shouldn't prepare user inputs
        if self._should_prepare_user_inputs(args):
            inputs = self._prepare_user_inputs(
                user_inputs=inputs,
                variables=app_config.variables,
        # parse files
        # TODO(QuantumGhost): Move file parsing logic to the API controller layer
        # for better separation of concerns.
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
        system_files = file_factory.build_from_mappings(
            mappings=files,
            tenant_id=app_model.tenant_id,
            config=file_extra_config,
            strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
            access_controller=self._file_access_controller,
        )
        # init application generate entity
        application_generate_entity = WorkflowAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            file_upload_config=file_extra_config,
            inputs=inputs,
            files=list(system_files),
            user_id=user.id,
            stream=streaming,
            invoke_from=invoke_from,
            call_depth=call_depth,
            trace_manager=trace_manager,
            workflow_execution_id=workflow_run_id,
            extras=extras,
        )

        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        # convert to app config
        app_config = WorkflowAppConfigManager.get_app_config(
            app_model=app_model,
            workflow=workflow,
        )

        # Create repositories
        #
        # Create session factory
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
        # Create workflow execution (aka workflow run) repository
        if triggered_from is not None:
            # Use explicitly provided triggered_from (for async triggers)
            workflow_triggered_from = triggered_from
        elif invoke_from == InvokeFrom.DEBUGGER:
            workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
        else:
            workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=workflow_triggered_from,
        )
        # Create workflow node execution repository
        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        # get tracing instance
        trace_manager = TraceQueueManager(
            app_id=app_model.id,
            user_id=user.id if isinstance(user, Account) else user.session_id,
        )

        return self._generate(
            app_model=app_model,
            workflow=workflow,
            user=user,
            application_generate_entity=application_generate_entity,
            invoke_from=invoke_from,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
            root_node_id=root_node_id,
            graph_engine_layers=graph_engine_layers,
            pause_state_config=pause_state_config,
        )
        inputs: Mapping[str, Any] = args["inputs"]

        extras = {
            **extract_external_trace_id_from_args(args),
        }
        workflow_run_id = str(workflow_run_id or uuid.uuid4())
        # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
        # trigger shouldn't prepare user inputs
        if self._should_prepare_user_inputs(args):
            inputs = self._prepare_user_inputs(
                user_inputs=inputs,
                variables=app_config.variables,
                tenant_id=app_model.tenant_id,
                strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
            )
        # init application generate entity
        application_generate_entity = WorkflowAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            file_upload_config=file_extra_config,
            inputs=inputs,
            files=list(system_files),
            user_id=user.id,
            stream=streaming,
            invoke_from=invoke_from,
            call_depth=call_depth,
            trace_manager=trace_manager,
            workflow_execution_id=workflow_run_id,
            extras=extras,
        )

        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        # Create repositories
        #
        # Create session factory
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
        # Create workflow execution (aka workflow run) repository
        if triggered_from is not None:
            # Use explicitly provided triggered_from (for async triggers)
            workflow_triggered_from = triggered_from
        elif invoke_from == InvokeFrom.DEBUGGER:
            workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
        else:
            workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=workflow_triggered_from,
        )
        # Create workflow node execution repository
        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        return self._generate(
            app_model=app_model,
            workflow=workflow,
            user=user,
            application_generate_entity=application_generate_entity,
            invoke_from=invoke_from,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
            root_node_id=root_node_id,
            graph_engine_layers=graph_engine_layers,
            pause_state_config=pause_state_config,
        )

    def resume(
        self,
@@ -292,62 +293,67 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :param workflow_node_execution_repository: repository for workflow node execution
        :param streaming: is stream
        """
        graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
        with self._bind_file_access_scope(
            tenant_id=application_generate_entity.app_config.tenant_id,
            user=user,
            invoke_from=invoke_from,
        ):
            graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)

        # init queue manager
        queue_manager = WorkflowAppQueueManager(
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            app_mode=app_model.mode,
        )

        if pause_state_config is not None:
            graph_layers.append(
                PauseStatePersistenceLayer(
                    session_factory=pause_state_config.session_factory,
                    generate_entity=application_generate_entity,
                    state_owner_user_id=pause_state_config.state_owner_user_id,
                )
        # init queue manager
        queue_manager = WorkflowAppQueueManager(
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            app_mode=app_model.mode,
        )

        # new thread with request context and contextvars
        context = contextvars.copy_context()
        if pause_state_config is not None:
            graph_layers.append(
                PauseStatePersistenceLayer(
                    session_factory=pause_state_config.session_factory,
                    generate_entity=application_generate_entity,
                    state_owner_user_id=pause_state_config.state_owner_user_id,
                )
            )

        # release database connection, because the following new thread operations may take a long time
        db.session.close()
        # new thread with request context and contextvars
        context = contextvars.copy_context()

        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "context": context,
                "variable_loader": variable_loader,
                "root_node_id": root_node_id,
                "workflow_execution_repository": workflow_execution_repository,
                "workflow_node_execution_repository": workflow_node_execution_repository,
                "graph_engine_layers": tuple(graph_layers),
                "graph_runtime_state": graph_runtime_state,
            },
        )
        # release database connection, because the following new thread operations may take a long time
        db.session.close()

        worker_thread.start()
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "context": context,
                "variable_loader": variable_loader,
                "root_node_id": root_node_id,
                "workflow_execution_repository": workflow_execution_repository,
                "workflow_node_execution_repository": workflow_node_execution_repository,
                "graph_engine_layers": tuple(graph_layers),
                "graph_runtime_state": graph_runtime_state,
            },
        )

        draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
        worker_thread.start()

        # return response or stream generator
        response = self._handle_response(
            application_generate_entity=application_generate_entity,
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            draft_var_saver_factory=draft_var_saver_factory,
            stream=streaming,
        )
        draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)

        return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
        # return response or stream generator
        response = self._handle_response(
            application_generate_entity=application_generate_entity,
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            draft_var_saver_factory=draft_var_saver_factory,
            stream=streaming,
        )

        return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def single_iteration_generate(
        self,

@@ -8,14 +8,15 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import get_default_root_node_id
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.enums import WorkflowType
from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import VariableLoader
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
@@ -96,7 +97,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
        inputs = self.application_generate_entity.inputs

        # Create a variable pool.
        system_inputs = SystemVariable(
        system_inputs = build_system_variables(
            files=self.application_generate_entity.files,
            user_id=self._sys_user_id,
            app_id=app_config.app_id,
@@ -104,12 +105,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
            workflow_id=app_config.workflow_id,
            workflow_execution_id=self.application_generate_entity.workflow_execution_id,
        )
        variable_pool = VariablePool(
            system_variables=system_inputs,
            user_inputs=inputs,
            environment_variables=self._workflow.environment_variables,
            conversation_variables=[],
        variable_pool = VariablePool()
        add_variables_to_pool(
            variable_pool,
            build_bootstrap_variables(
                system_variables=system_inputs,
                environment_variables=self._workflow.environment_variables,
            ),
        )
        root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict)
        add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)

        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
        graph = self._init_graph(
@@ -120,7 +125,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
            user_id=self.application_generate_entity.user_id,
            user_from=user_from,
            invoke_from=invoke_from,
            root_node_id=self._root_node_id,
            root_node_id=root_node_id,
        )

        # RUN WORKFLOW

@@ -10,6 +10,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
    AppQueueEvent,
@@ -55,11 +56,10 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.system_variables import build_system_variables
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.runtime import GraphRuntimeState
from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
from models import Account
from models.enums import CreatorUserRole
@@ -104,7 +104,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        self._invoke_from = queue_manager.invoke_from
        self._draft_var_saver_factory = draft_var_saver_factory
        self._workflow = workflow
        self._workflow_system_variables = SystemVariable(
        self._workflow_system_variables = build_system_variables(
            files=application_generate_entity.files,
            user_id=user_session_id,
            app_id=application_generate_entity.app_config.app_id,
@@ -728,13 +728,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        return response

    def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
        with Session(db.engine) as session, session.begin():
            saver = self._draft_var_saver_factory(
                session=session,
                app_id=self._application_generate_entity.app_config.app_id,
                node_id=event.node_id,
                node_type=event.node_type,
                node_execution_id=node_execution_id,
                enclosing_node_id=event.in_loop_id or event.in_iteration_id,
            )
            saver.save(event.process_data, event.outputs)
        saver = self._draft_var_saver_factory(
            app_id=self._application_generate_entity.app_config.app_id,
            node_id=event.node_id,
            node_type=event.node_type,
            node_execution_id=node_execution_id,
            enclosing_node_id=event.in_loop_id or event.in_iteration_id,
        )
        saver.save(event.process_data, event.outputs)

@@ -34,7 +34,16 @@ from core.app.entities.queue_entities import (
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.system_variables import (
    build_bootstrap_variables,
    default_system_variables,
    get_node_creation_preload_selectors,
    inject_default_system_variable_mappings,
    preload_node_creation_variables,
)
from core.workflow.variable_pool_initializer import add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired
@@ -68,7 +77,6 @@ from dify_graph.graph_events import (
)
from dify_graph.graph_events.graph import GraphRunAbortedEvent
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from models.workflow import Workflow
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
@@ -173,14 +181,15 @@ class WorkflowBasedAppRunner:
            ValueError: If neither single_iteration_run nor single_loop_run is specified
        """
        # Create initial runtime state with variable pool containing environment variables
        graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(
                system_variables=SystemVariable.default(),
                user_inputs={},
        variable_pool = VariablePool()
        add_variables_to_pool(
            variable_pool,
            build_bootstrap_variables(
                system_variables=default_system_variables(),
                environment_variables=workflow.environment_variables,
            ),
            start_at=time.time(),
        )
        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())

        # Determine which type of single node execution and get graph/variable_pool
        if single_iteration_run:
@@ -272,6 +281,8 @@ class WorkflowBasedAppRunner:

        graph_config["edges"] = edge_configs

        typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs]

        # Create required parameters for Graph.init
        graph_init_params = GraphInitParams(
            workflow_id=workflow.id,
@@ -291,26 +302,15 @@ class WorkflowBasedAppRunner:
            graph_runtime_state=graph_runtime_state,
        )

        # init graph
        graph = Graph.init(
            graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True
        )

        if not graph:
            raise ValueError("graph not found in workflow")

        # fetch node config from node id
        target_node_config = None
        for node in node_configs:
            if node.get("id") == node_id:
        for node in typed_node_configs:
            if node["id"] == node_id:
                target_node_config = node
                break

        if not target_node_config:
            raise ValueError(f"{node_type_label} node id not found in workflow graph")

        target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)

        # Get node class
        node_type = target_node_config["data"].type
        node_version = str(target_node_config["data"].version)
@@ -319,12 +319,31 @@ class WorkflowBasedAppRunner:
        # Use the variable pool from graph_runtime_state instead of creating a new one
        variable_pool = graph_runtime_state.variable_pool

        preload_node_creation_variables(
            variable_loader=self._variable_loader,
            variable_pool=variable_pool,
            selectors=[
                selector
                for node_config in typed_node_configs
                for selector in get_node_creation_preload_selectors(
                    node_type=node_config["data"].type,
                    node_data=node_config["data"],
                )
            ],
        )

        try:
            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
                graph_config=workflow.graph_dict, config=target_node_config
            )
        except NotImplementedError:
            variable_mapping = {}
        variable_mapping = inject_default_system_variable_mappings(
            node_id=target_node_config["id"],
            node_type=node_type,
            node_data=target_node_config["data"],
            variable_mapping=variable_mapping,
        )

        load_into_variable_pool(
            variable_loader=self._variable_loader,
@@ -340,6 +359,14 @@ class WorkflowBasedAppRunner:
            tenant_id=workflow.tenant_id,
        )

        # init graph after constructor-time context has been loaded
        graph = Graph.init(
            graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True
        )

        if not graph:
            raise ValueError("graph not found in workflow")

        return graph, variable_pool

    @staticmethod
@@ -408,7 +435,11 @@ class WorkflowBasedAppRunner:
            node_run_result = event.node_run_result
            inputs = node_run_result.inputs
            process_data = node_run_result.process_data
            outputs = node_run_result.outputs
            outputs = project_node_outputs_for_workflow_run(
                node_type=event.node_type,
                inputs=inputs,
                outputs=node_run_result.outputs,
            )
            execution_metadata = node_run_result.metadata
            self._publish_event(
                QueueNodeRetryEvent(
@@ -448,7 +479,11 @@ class WorkflowBasedAppRunner:
            node_run_result = event.node_run_result
            inputs = node_run_result.inputs
            process_data = node_run_result.process_data
            outputs = node_run_result.outputs
            outputs = project_node_outputs_for_workflow_run(
                node_type=event.node_type,
                inputs=inputs,
                outputs=node_run_result.outputs,
            )
            execution_metadata = node_run_result.metadata
            self._publish_event(
                QueueNodeSucceededEvent(
@@ -466,6 +501,11 @@ class WorkflowBasedAppRunner:
                )
            )
        elif isinstance(event, NodeRunFailedEvent):
            outputs = project_node_outputs_for_workflow_run(
                node_type=event.node_type,
                inputs=event.node_run_result.inputs,
                outputs=event.node_run_result.outputs,
            )
            self._publish_event(
                QueueNodeFailedEvent(
                    node_execution_id=event.id,
@@ -475,7 +515,7 @@ class WorkflowBasedAppRunner:
                    finished_at=event.finished_at,
                    inputs=event.node_run_result.inputs,
                    process_data=event.node_run_result.process_data,
                    outputs=event.node_run_result.outputs,
                    outputs=outputs,
                    error=event.node_run_result.error or "Unknown error",
                    execution_metadata=event.node_run_result.metadata,
                    in_iteration_id=event.in_iteration_id,
@@ -483,6 +523,11 @@ class WorkflowBasedAppRunner:
                )
            )
        elif isinstance(event, NodeRunExceptionEvent):
            outputs = project_node_outputs_for_workflow_run(
                node_type=event.node_type,
                inputs=event.node_run_result.inputs,
                outputs=event.node_run_result.outputs,
            )
            self._publish_event(
                QueueNodeExceptionEvent(
                    node_execution_id=event.id,
@@ -492,7 +537,7 @@ class WorkflowBasedAppRunner:
                    finished_at=event.finished_at,
                    inputs=event.node_run_result.inputs,
                    process_data=event.node_run_result.process_data,
                    outputs=event.node_run_result.outputs,
                    outputs=outputs,
                    error=event.node_run_result.error or "Unknown error",
                    execution_metadata=event.node_run_result.metadata,
                    in_iteration_id=event.in_iteration_id,

@@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.file import File, FileUploadConfig
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
||||
@@ -15,6 +14,9 @@ if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
DIFY_RUN_CONTEXT_KEY = "_dify"
|
||||
|
||||
|
||||
class UserFrom(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
api/core/app/file_access/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from .controller import DatabaseFileAccessController
from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope

__all__ = [
    "DatabaseFileAccessController",
    "FileAccessControllerProtocol",
    "FileAccessScope",
    "bind_file_access_scope",
    "get_current_file_access_scope",
]
api/core/app/file_access/controller.py (new file, 103 lines)
@@ -0,0 +1,103 @@
from __future__ import annotations

from collections.abc import Callable

from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from models import ToolFile, UploadFile
from models.enums import CreatorUserRole

from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, get_current_file_access_scope


class DatabaseFileAccessController(FileAccessControllerProtocol):
    """Workflow-layer authorization helper for database-backed file lookups.

    Tenant scoping remains mandatory. When the current execution belongs to an
    end user, the lookup is additionally constrained to that end user's file
    ownership markers.
    """

    _scope_getter: Callable[[], FileAccessScope | None]

    def __init__(
        self,
        *,
        scope_getter: Callable[[], FileAccessScope | None] = get_current_file_access_scope,
    ) -> None:
        self._scope_getter = scope_getter

    def current_scope(self) -> FileAccessScope | None:
        return self._scope_getter()

    def apply_upload_file_filters(
        self,
        stmt: Select[tuple[UploadFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[UploadFile]]:
        resolved_scope = scope or self.current_scope()
        if resolved_scope is None:
            return stmt

        scoped_stmt = stmt.where(UploadFile.tenant_id == resolved_scope.tenant_id)
        if not resolved_scope.requires_user_ownership:
            return scoped_stmt

        return scoped_stmt.where(
            UploadFile.created_by_role == CreatorUserRole.END_USER,
            UploadFile.created_by == resolved_scope.user_id,
        )

    def apply_tool_file_filters(
        self,
        stmt: Select[tuple[ToolFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[ToolFile]]:
        resolved_scope = scope or self.current_scope()
        if resolved_scope is None:
            return stmt

        scoped_stmt = stmt.where(ToolFile.tenant_id == resolved_scope.tenant_id)
        if not resolved_scope.requires_user_ownership:
            return scoped_stmt

        return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)

    def get_upload_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> UploadFile | None:
        resolved_scope = scope or self.current_scope()
        if resolved_scope is None:
            return session.get(UploadFile, file_id)

        stmt = self.apply_upload_file_filters(
            select(UploadFile).where(UploadFile.id == file_id),
            scope=resolved_scope,
        )
        return session.scalar(stmt)

    def get_tool_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> ToolFile | None:
        resolved_scope = scope or self.current_scope()
        if resolved_scope is None:
            return session.get(ToolFile, file_id)

        stmt = self.apply_tool_file_filters(
            select(ToolFile).where(ToolFile.id == file_id),
            scope=resolved_scope,
        )
        return session.scalar(stmt)
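A short usage sketch of the controller above, using only names defined or imported in this file (the literal ids are illustrative). For an end-user scope, requires_user_ownership is True, so the query gains the tenant filter plus the created_by ownership filters:

    # Sketch only; FileAccessScope fields and filters as defined in this diff.
    scope = FileAccessScope(
        tenant_id="tenant-1",
        user_id="end-user-1",
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.SERVICE_API,
    )
    controller = DatabaseFileAccessController(scope_getter=lambda: scope)
    stmt = controller.apply_upload_file_filters(select(UploadFile))
    # stmt now constrains tenant_id, created_by_role == END_USER, and created_by.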
api/core/app/file_access/protocols.py (new file, 81 lines)
@@ -0,0 +1,81 @@
from __future__ import annotations

from typing import Protocol

from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from models import ToolFile, UploadFile

from .scope import FileAccessScope


class FileAccessControllerProtocol(Protocol):
    """Contract for applying access rules to file lookups.

    Implementations translate an optional execution scope into query constraints
    and authorized record retrieval. The contract is intentionally limited to
    ownership and tenancy rules for workflow-layer file access.
    """

    def current_scope(self) -> FileAccessScope | None:
        """Return the scope active for the current execution, if one exists.

        Callers use this to decide whether embedded file metadata may be trusted
        or whether a fresh authorized lookup is required.
        """
        ...

    def apply_upload_file_filters(
        self,
        stmt: Select[tuple[UploadFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[UploadFile]]:
        """Return an upload-file query constrained by the supplied access scope.

        The returned statement must preserve the caller's existing predicates and
        append only access-control conditions.
        """
        ...

    def apply_tool_file_filters(
        self,
        stmt: Select[tuple[ToolFile]],
        *,
        scope: FileAccessScope | None = None,
    ) -> Select[tuple[ToolFile]]:
        """Return a tool-file query constrained by the supplied access scope.

        The returned statement must preserve the caller's existing predicates and
        append only access-control conditions.
        """
        ...

    def get_upload_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> UploadFile | None:
        """Load one authorized upload-file record for the given identifier.

        Returns ``None`` when the file does not exist or when the scope does not
        permit access to that record.
        """
        ...

    def get_tool_file(
        self,
        *,
        session: Session,
        file_id: str,
        scope: FileAccessScope | None = None,
    ) -> ToolFile | None:
        """Load one authorized tool-file record for the given identifier.

        Returns ``None`` when the file does not exist or when the scope does not
        permit access to that record.
        """
        ...
api/core/app/file_access/scope.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom

_current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar(
    "current_file_access_scope",
    default=None,
)


@dataclass(frozen=True, slots=True)
class FileAccessScope:
    """Request-scoped ownership context used by workflow-layer file lookups."""

    tenant_id: str
    user_id: str
    user_from: UserFrom
    invoke_from: InvokeFrom

    @property
    def requires_user_ownership(self) -> bool:
        return self.user_from == UserFrom.END_USER


def get_current_file_access_scope() -> FileAccessScope | None:
    return _current_file_access_scope.get()


@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
    token = _current_file_access_scope.set(scope)
    try:
        yield
    finally:
        _current_file_access_scope.reset(token)
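Because the scope lives in a ContextVar, the binding is per execution context and is reset even when the body raises. A minimal sketch of the intended call pattern, using the API defined in this file (values illustrative):

    # Sketch only; uses the scope API defined above.
    scope = FileAccessScope(
        tenant_id="tenant-1",
        user_id="user-1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
    )
    with bind_file_access_scope(scope):
        assert get_current_file_access_scope() is scope  # visible to nested file lookups
    assert get_current_file_access_scope() is None       # restored after the block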
@@ -1,12 +1,19 @@
"""
Persist conversation-scoped variable updates emitted by the graph engine.

The graph package emits generic variable update events and stays unaware of
conversation identity or storage concerns. This layer lives in the application
core, listens to those generic events, and persists only the `conversation.*`
scope updates that matter to chat applications.
"""

import logging

from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.conversation_variable_updater import ConversationVariableUpdater
from dify_graph.enums import BuiltinNodeTypes
from core.workflow.system_variables import SystemVariableKey, get_system_text
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
from dify_graph.variables import VariableBase
from dify_graph.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent
from services.conversation_variable_updater import ConversationVariableUpdater

logger = logging.getLogger(__name__)

@@ -20,41 +27,22 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
        pass

    def on_event(self, event: GraphEngineEvent) -> None:
        if not isinstance(event, NodeRunSucceededEvent):
            return
        if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER:
            return
        if self.graph_runtime_state is None:
        if not isinstance(event, NodeRunVariableUpdatedEvent):
            return

        updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
        if not updated_variables:
        selector = event.variable.selector
        if len(selector) < 2:
            logger.warning("Conversation variable selector invalid. selector=%s", selector)
            return

        conversation_id = self.graph_runtime_state.system_variable.conversation_id
        conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID)
        if conversation_id is None:
            return

        updated_any = False
        for item in updated_variables:
            selector = item.selector
            if len(selector) < 2:
                logger.warning("Conversation variable selector invalid. selector=%s", selector)
                continue
            if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
                continue
            variable = self.graph_runtime_state.variable_pool.get(selector)
            if not isinstance(variable, VariableBase):
                logger.warning(
                    "Conversation variable not found in variable pool. selector=%s",
                    selector,
                )
                continue
            self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
            updated_any = True
        if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
            return

        if updated_any:
            self._conversation_variable_updater.flush()
        self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable)

    def on_graph_end(self, error: Exception | None) -> None:
        pass

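The rewritten on_event above narrows the layer from reconstructing updates out of VARIABLE_ASSIGNER process data to reacting to one NodeRunVariableUpdatedEvent per update, persisting only selectors rooted at CONVERSATION_VARIABLE_NODE_ID. A hedged sketch of the filtering rule it applies (the literal node id value is an assumption, not confirmed by this diff):

    # Sketch only; mirrors the selector checks in on_event above.
    def should_persist(selector: list[str]) -> bool:
        if len(selector) < 2:
            return False  # malformed selector: logged and skipped above
        return selector[0] == CONVERSATION_VARIABLE_NODE_ID

    assert should_persist(["conversation", "topic"]) is True  # assuming the node id is "conversation"
    assert should_persist(["env", "api_key"]) is False
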
@@ -6,6 +6,7 @@ from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker

from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.system_variables import SystemVariableKey, get_system_text
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events.base import GraphEngineEvent
from dify_graph.graph_events.graph import GraphRunPausedEvent
@@ -119,7 +120,10 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
            generate_entity=entity_wrapper,
        )

        workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
        workflow_run_id = get_system_text(
            self.graph_runtime_state.variable_pool,
            SystemVariableKey.WORKFLOW_EXECUTION_ID,
        )
        assert workflow_run_id is not None
        repo = self._get_repo()
        repo.create_workflow_pause(

@@ -5,6 +5,7 @@ from typing import Any, ClassVar
from pydantic import TypeAdapter

from core.db.session_factory import session_factory
from core.workflow.system_variables import SystemVariableKey, get_system_text
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events.base import GraphEngineEvent
from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
@@ -59,7 +60,10 @@ class TriggerPostLayer(GraphEngineLayer):
        outputs = self.graph_runtime_state.outputs

        # Basically, workflow_execution_id is the same as workflow_run_id
        workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id
        workflow_run_id = get_system_text(
            self.graph_runtime_state.variable_pool,
            SystemVariableKey.WORKFLOW_EXECUTION_ID,
        )
        assert workflow_run_id, "Workflow run id is not set"

        total_tokens = self.graph_runtime_state.total_tokens

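Both layers above now read workflow_execution_id through get_system_text instead of a typed SystemVariable attribute. A plausible sketch of what such an accessor does, assuming system variables live in the pool under a dedicated system node id (the "sys" selector layout and segment API are assumptions, not confirmed by this diff):

    # Sketch only; the "sys" node id and segment access are assumptions.
    def get_system_text(variable_pool: VariablePool, key: SystemVariableKey) -> str | None:
        segment = variable_pool.get(["sys", key.value])
        return str(segment.value) if segment is not None else None
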
@@ -2,23 +2,34 @@ from __future__ import annotations

from typing import Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.nodes.llm.entities import ModelConfig
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider


class DifyCredentialsProvider:
    tenant_id: str
    provider_manager: ProviderManager

    def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
        self.tenant_id = tenant_id
        self.provider_manager = provider_manager or ProviderManager()
    def __init__(
        self,
        *,
        run_context: DifyRunContext,
        provider_manager: ProviderManager | None = None,
    ) -> None:
        self.tenant_id = run_context.tenant_id
        if provider_manager is None:
            provider_manager = create_plugin_provider_manager(
                tenant_id=run_context.tenant_id,
                user_id=run_context.user_id,
            )
        self.provider_manager = provider_manager

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
@@ -42,9 +53,21 @@ class DifyModelFactory:
    tenant_id: str
    model_manager: ModelManager

    def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
        self.tenant_id = tenant_id
        self.model_manager = model_manager or ModelManager()
    def __init__(
        self,
        *,
        run_context: DifyRunContext,
        model_manager: ModelManager | None = None,
    ) -> None:
        self.tenant_id = run_context.tenant_id
        if model_manager is None:
            model_manager = ModelManager(
                provider_manager=create_plugin_provider_manager(
                    tenant_id=run_context.tenant_id,
                    user_id=run_context.user_id,
                )
            )
        self.model_manager = model_manager

    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
        return self.model_manager.get_model_instance(
@@ -55,18 +78,42 @@ class DifyModelFactory:
        )


def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
    return (
        DifyCredentialsProvider(tenant_id=tenant_id),
        DifyModelFactory(tenant_id=tenant_id),
def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, DifyModelFactory]:
    """Create LLM access adapters that share the same tenant-bound manager graph."""
    provider_manager = create_plugin_provider_manager(
        tenant_id=run_context.tenant_id,
        user_id=run_context.user_id,
    )
    model_manager = ModelManager(provider_manager=provider_manager)

    return (
        DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
        DifyModelFactory(run_context=run_context, model_manager=model_manager),
    )
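Constructing both adapters through build_dify_model_access guarantees they share one plugin-backed ProviderManager/ModelManager pair instead of each building its own. A brief usage sketch, assuming tenant_id and user_id suffice to construct DifyRunContext (the field values and provider/model names are illustrative):

    # Sketch only; DifyRunContext fields per their use above.
    run_context = DifyRunContext(tenant_id="tenant-1", user_id="user-1")
    credentials_provider, model_factory = build_dify_model_access(run_context)
    credentials = credentials_provider.fetch("openai", "gpt-4o")           # names illustrative
    model_instance = model_factory.init_model_instance("openai", "gpt-4o")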


def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
    """
    Split node-level completion params into provider parameters and stop sequences.

    Workflow LLM-compatible nodes still consume runtime invocation settings from
    ``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the
    ``ModelInstance`` view and the returned config entity aligned here so callers
    do not need to duplicate normalization logic.
    """
    normalized_parameters = dict(completion_params)
    stop = normalized_parameters.pop("stop", [])
    if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop):
        stop = []

    return normalized_parameters, stop
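Concretely, the helper strips a well-formed stop list out of the provider parameters and discards a malformed one instead of forwarding it. A small worked example of the behavior defined above:

    # Sketch only; behavior follows _normalize_completion_params above.
    params, stop = _normalize_completion_params({"temperature": 0.7, "stop": ["\n\n", "END"]})
    assert params == {"temperature": 0.7} and stop == ["\n\n", "END"]

    params, stop = _normalize_completion_params({"temperature": 0.7, "stop": "END"})
    assert stop == []  # a non-list (or non-str-item) stop is dropped, not forwarded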


def fetch_model_config(
    *,
    node_data_model: ModelConfig,
    credentials_provider: CredentialsProvider,
    model_factory: ModelFactory,
    model_factory: DifyModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    if not node_data_model.mode:
        raise LLMModeRequiredError("LLM mode is required.")
@@ -80,22 +127,18 @@ def fetch_model_config(
        model_type=ModelType.LLM,
    )
    if provider_model is None:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
        raise ModelNotExistError(f"Model {node_data_model.name} does not exist.")
    provider_model.raise_for_status()

    completion_params = dict(node_data_model.completion_params)
    stop = completion_params.pop("stop", [])
    if not isinstance(stop, list):
        stop = []

    model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
    if not model_schema:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
    if model_schema is None:
        raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.")

    parameters, stop = _normalize_completion_params(node_data_model.completion_params)
    model_instance.provider = node_data_model.provider
    model_instance.model_name = node_data_model.name
    model_instance.credentials = credentials
    model_instance.parameters = completion_params
    model_instance.parameters = parameters
    model_instance.stop = tuple(stop)

    return model_instance, ModelConfigWithCredentialsEntity(
@@ -103,8 +146,8 @@ def fetch_model_config(
        model=node_data_model.name,
        model_schema=model_schema,
        mode=node_data_model.mode,
        provider_model_bundle=provider_model_bundle,
        credentials=credentials,
        parameters=completion_params,
        parameters=parameters,
        stop=stop,
        provider_model_bundle=provider_model_bundle,
    )

@@ -1,33 +1,42 @@
from __future__ import annotations

import base64
import hashlib
import hmac
import os
import time
import urllib.parse
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal

from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from dify_graph.file.enums import FileTransferMethod
from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
from dify_graph.file.runtime import set_workflow_file_runtime
from extensions.ext_storage import storage

if TYPE_CHECKING:
    from dify_graph.file.models import File


class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
    """Production runtime wiring for ``dify_graph.file``."""
    """Production runtime wiring for ``dify_graph.file``.

    @property
    def files_url(self) -> str:
        return dify_config.FILES_URL
    Opaque file references are resolved back to canonical database records before
    URLs are signed or storage keys are used. When a request-scoped file access
    scope is present, those lookups additionally enforce tenant and end-user
    ownership filters.
    """

    @property
    def internal_files_url(self) -> str | None:
        return dify_config.INTERNAL_FILES_URL
    _file_access_controller: FileAccessControllerProtocol

    @property
    def secret_key(self) -> str:
        return dify_config.SECRET_KEY

    @property
    def files_access_timeout(self) -> int:
        return dify_config.FILES_ACCESS_TIMEOUT
    def __init__(self, *, file_access_controller: FileAccessControllerProtocol) -> None:
        self._file_access_controller = file_access_controller

    @property
    def multimodal_send_format(self) -> str:
@@ -39,9 +48,137 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
    def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
        return storage.load(path, stream=stream)

    def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
    def load_file_bytes(self, *, file: File) -> bytes:
        storage_key = self._resolve_storage_key(file=file)
        data = storage.load(storage_key, stream=False)
        if not isinstance(data, bytes):
            raise ValueError(f"file {storage_key} is not a bytes object")
        return data

    def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
        if file.transfer_method == FileTransferMethod.REMOTE_URL:
            return file.remote_url
        parsed_reference = parse_file_reference(file.reference)
        if parsed_reference is None:
            raise ValueError("Missing file reference")
        if file.transfer_method == FileTransferMethod.LOCAL_FILE:
            return self.resolve_upload_file_url(
                upload_file_id=parsed_reference.record_id,
                for_external=for_external,
            )
        if file.transfer_method == FileTransferMethod.DATASOURCE_FILE:
            if file.extension is None:
                raise ValueError("Missing file extension")
            self._assert_upload_file_access(upload_file_id=parsed_reference.record_id)
            return sign_tool_file(
                tool_file_id=parsed_reference.record_id,
                extension=file.extension,
                for_external=for_external,
            )
        if file.transfer_method == FileTransferMethod.TOOL_FILE:
            if file.extension is None:
                raise ValueError("Missing file extension")
            return self.resolve_tool_file_url(
                tool_file_id=parsed_reference.record_id,
                extension=file.extension,
                for_external=for_external,
            )
        return None

    def resolve_upload_file_url(
        self,
        *,
        upload_file_id: str,
        as_attachment: bool = False,
        for_external: bool = True,
    ) -> str:
        self._assert_upload_file_access(upload_file_id=upload_file_id)
        base_url = self._base_url(for_external=for_external)
        url = f"{base_url}/files/{upload_file_id}/file-preview"
        query = self._sign_query(payload=f"file-preview|{upload_file_id}")
        if as_attachment:
            query["as_attachment"] = "true"
        return f"{url}?{urllib.parse.urlencode(query)}"

    def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
        self._assert_tool_file_access(tool_file_id=tool_file_id)
        return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)

    def verify_preview_signature(
        self,
        *,
        preview_kind: Literal["image", "file"],
        file_id: str,
        timestamp: str,
        nonce: str,
        sign: str,
    ) -> bool:
        payload = f"{preview_kind}-preview|{file_id}|{timestamp}|{nonce}"
        recalculated = hmac.new(self._secret_key(), payload.encode(), hashlib.sha256).digest()
        if sign != base64.urlsafe_b64encode(recalculated).decode():
            return False
        return int(time.time()) - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

    @staticmethod
    def _base_url(*, for_external: bool) -> str:
        if for_external:
            return dify_config.FILES_URL
        return dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL

    @staticmethod
    def _secret_key() -> bytes:
        return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""

    def _sign_query(self, *, payload: str) -> dict[str, str]:
        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        sign = hmac.new(self._secret_key(), f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest()
        return {
            "timestamp": timestamp,
            "nonce": nonce,
            "sign": base64.urlsafe_b64encode(sign).decode(),
        }
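_sign_query and verify_preview_signature form an HMAC round trip: the signer binds payload, timestamp, and nonce under SECRET_KEY, and the verifier recomputes the digest and enforces FILES_ACCESS_TIMEOUT. A sketch of the round trip, assuming the two methods above and a configured SECRET_KEY (the file id placeholder is left unfilled):

    # Sketch only; exercises the two methods defined above.
    runtime = DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())
    query = runtime._sign_query(payload="file-preview|<file_id>")
    ok = runtime.verify_preview_signature(
        preview_kind="file",
        file_id="<file_id>",
        timestamp=query["timestamp"],
        nonce=query["nonce"],
        sign=query["sign"],
    )
    assert ok  # holds while the timestamp is within FILES_ACCESS_TIMEOUT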
|
||||
    def _resolve_storage_key(self, *, file: File) -> str:
        parsed_reference = parse_file_reference(file.reference)
        if parsed_reference is None:
            raise ValueError("Missing file reference")

        record_id = parsed_reference.record_id
        with session_factory.create_session() as session:
            if file.transfer_method in {
                FileTransferMethod.LOCAL_FILE,
                FileTransferMethod.REMOTE_URL,
                FileTransferMethod.DATASOURCE_FILE,
            }:
                upload_file = self._file_access_controller.get_upload_file(session=session, file_id=record_id)
                if upload_file is None:
                    raise ValueError(f"Upload file {record_id} not found")
                return upload_file.key

            tool_file = self._file_access_controller.get_tool_file(session=session, file_id=record_id)
            if tool_file is None:
                raise ValueError(f"Tool file {record_id} not found")
            return tool_file.file_key

    def _assert_upload_file_access(self, *, upload_file_id: str) -> None:
        if self._file_access_controller.current_scope() is None:
            return

        with session_factory.create_session() as session:
            upload_file = self._file_access_controller.get_upload_file(session=session, file_id=upload_file_id)
            if upload_file is None:
                raise ValueError(f"Upload file {upload_file_id} not found")

    def _assert_tool_file_access(self, *, tool_file_id: str) -> None:
        if self._file_access_controller.current_scope() is None:
            return

        with session_factory.create_session() as session:
            tool_file = self._file_access_controller.get_tool_file(session=session, file_id=tool_file_id)
            if tool_file is None:
                raise ValueError(f"Tool file {tool_file_id} not found")


def bind_dify_workflow_file_runtime() -> None:
    set_workflow_file_runtime(DifyWorkflowFileRuntime())
    set_workflow_file_runtime(DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController()))

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, cast, final

from typing_extensions import override

from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
@@ -75,7 +76,7 @@ class LLMQuotaLayer(GraphEngineLayer):
            return

        try:
            dify_ctx = node.require_dify_context()
            dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
            deduct_llm_quota(
                tenant_id=dify_ctx.tenant_id,
                model_instance=model_instance,
@@ -114,11 +115,11 @@ class LLMQuotaLayer(GraphEngineLayer):
        try:
            match node.node_type:
                case BuiltinNodeTypes.LLM:
                    return cast("LLMNode", node).model_instance
                    model_instance = cast("LLMNode", node).model_instance
                case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
                    return cast("ParameterExtractorNode", node).model_instance
                    model_instance = cast("ParameterExtractorNode", node).model_instance
                case BuiltinNodeTypes.QUESTION_CLASSIFIER:
                    return cast("QuestionClassifierNode", node).model_instance
                    model_instance = cast("QuestionClassifierNode", node).model_instance
                case _:
                    return None
        except AttributeError:
@@ -127,3 +128,12 @@ class LLMQuotaLayer(GraphEngineLayer):
                node.id,
            )
            return None

        if isinstance(model_instance, ModelInstance):
            return model_instance

        raw_model_instance = getattr(model_instance, "_model_instance", None)
        if isinstance(raw_model_instance, ModelInstance):
            return raw_model_instance

        return None

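The new tail of this method unwraps a wrapper: quota accounting needs the raw core `ModelInstance`, but some nodes now hand back a runtime object that merely holds one. A hedged sketch of that duck-typed fallback (the wrapper class here is illustrative, not from the codebase):

class ModelInstance: ...  # stands in for core.model_manager.ModelInstance

class WrappedModelInstance:
    def __init__(self, inner: ModelInstance):
        self._model_instance = inner  # the private handle the layer falls back to

def unwrap(obj) -> ModelInstance | None:
    if isinstance(obj, ModelInstance):
        return obj  # already the raw instance
    raw = getattr(obj, "_model_instance", None)
    return raw if isinstance(raw, ModelInstance) else None
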
@@ -17,10 +17,12 @@ from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.system_variables import SystemVariableKey
from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run
from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution
from dify_graph.enums import (
    SystemVariableKey,
    WorkflowExecutionStatus,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
@@ -43,8 +45,6 @@ from dify_graph.graph_events import (
    NodeRunSucceededEvent,
)
from dify_graph.node_events import NodeRunResult
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from libs.datetime_utils import naive_utc_now


@@ -372,10 +372,15 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
            domain_execution.error = error

        if update_outputs:
            projected_outputs = project_node_outputs_for_workflow_run(
                node_type=domain_execution.node_type,
                inputs=node_result.inputs,
                outputs=node_result.outputs,
            )
            domain_execution.update_from_mapping(
                inputs=node_result.inputs,
                process_data=node_result.process_data,
                outputs=node_result.outputs,
                outputs=projected_outputs,
                metadata=node_result.metadata,
            )


@@ -25,12 +25,10 @@ class AudioTrunk:
        self.status = status


def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str):
    if not text_content or text_content.isspace():
        return
    return model_instance.invoke_tts(
        content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
    )
    return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice)


def _process_future(
@@ -62,7 +60,7 @@ class AppGeneratorTTSPublisher:
        self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
        self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
        self.match = re.compile(r"[。.!?]")
        self.model_manager = ModelManager()
        self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts")
        self.model_instance = self.model_manager.get_default_model_instance(
            tenant_id=self.tenant_id, model_type=ModelType.TTS
        )
@@ -89,7 +87,7 @@ class AppGeneratorTTSPublisher:
        if message is None:
            if self.msg_text and len(self.msg_text.strip()) > 0:
                futures_result = self.executor.submit(
                    _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
                    _invoice_tts, self.msg_text, self.model_instance, self.voice
                )
                future_queue.put(futures_result)
                break
@@ -117,9 +115,7 @@ class AppGeneratorTTSPublisher:
                if len(sentence_arr) >= min(self.max_sentence, 7):
                    self.max_sentence += 1
                    text_content = "".join(sentence_arr)
                    futures_result = self.executor.submit(
                        _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
                    )
                    futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice)
                    future_queue.put(futures_result)
                if isinstance(text_tmp, str):
                    self.msg_text = text_tmp

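The net effect of these TTS hunks is that tenant and acting user now travel with the ModelManager rather than with every call. A hedged sketch of the new call shape (tenant id, voice, and text are placeholders):

manager = ModelManager.for_tenant(tenant_id="tenant-123", user_id="responding_tts")
tts = manager.get_default_model_instance(tenant_id="tenant-123", model_type=ModelType.TTS)
# invoke_tts no longer takes tenant_id/user; scoping came from for_tenant() above.
for audio_chunk in tts.invoke_tts(content_text="Hello there.", voice="default"):
    ...  # stream bytes into the audio queue
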
@@ -6,6 +6,7 @@ from typing import Any, cast
from sqlalchemy import select

import contexts
from core.app.file_access import DatabaseFileAccessController
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.datasource_entities import (
@@ -24,10 +25,11 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTran
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.workflow.file_reference import build_file_reference
from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.file import File
from dify_graph.file import File, get_file_type_by_mime_type
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from factories import file_factory
@@ -36,6 +38,7 @@ from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService

logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()


class DatasourceManager:
@@ -279,11 +282,15 @@ class DatasourceManager:
                    if datasource_file is not None:
                        mapping = {
                            "tool_file_id": datasource_file_id,
                            "type": file_factory.get_file_type_by_mime_type(mime_type),
                            "type": get_file_type_by_mime_type(mime_type),
                            "transfer_method": FileTransferMethod.TOOL_FILE,
                            "url": url,
                        }
                        file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
                        file_out = file_factory.build_from_mapping(
                            mapping=mapping,
                            tenant_id=tenant_id,
                            access_controller=_file_access_controller,
                        )
                elif mtype == DatasourceMessage.MessageType.TEXT:
                    assert isinstance(message.message, DatasourceMessage.TextMessage)
                    yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
@@ -351,11 +358,10 @@ class DatasourceManager:
                filename=upload_file.name,
                extension="." + upload_file.extension,
                mime_type=upload_file.mime_type,
                tenant_id=tenant_id,
                type=FileType.CUSTOM,
                transfer_method=FileTransferMethod.LOCAL_FILE,
                remote_url=upload_file.source_url,
                related_id=upload_file.id,
                reference=build_file_reference(record_id=str(upload_file.id)),
                size=upload_file.size,
                storage_key=upload_file.key,
                url=upload_file.source_url,

@@ -4,6 +4,7 @@ from mimetypes import guess_extension, guess_type

from core.datasource.entities.datasource_entities import DatasourceMessage
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.file_reference import parse_file_reference
from dify_graph.file import File, FileTransferMethod, FileType
from models.tools import ToolFile

@@ -103,8 +104,14 @@ class DatasourceFileMessageTransformer:
                file: File | None = meta.get("file")
                if isinstance(file, File):
                    if file.transfer_method == FileTransferMethod.TOOL_FILE:
                        assert file.related_id is not None
                        url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
                        reference = getattr(file, "reference", None) or getattr(file, "related_id", None)
                        parsed_reference = parse_file_reference(reference) if isinstance(reference, str) else None
                        if parsed_reference is None:
                            raise ValueError("datasource file is missing reference")
                        url = cls.get_datasource_file_url(
                            datasource_file_id=parsed_reference.record_id,
                            extension=file.extension,
                        )
                        if file.type == FileType.IMAGE:
                            yield DatasourceMessage(
                                type=DatasourceMessage.MessageType.IMAGE_LINK,

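Several hunks above lean on the `build_file_reference` / `parse_file_reference` pair. A hedged sketch of the implied round trip (the encoding of the reference string is an assumption; only the `record_id` round trip and the None failure mode are implied by the diff):

from core.workflow.file_reference import build_file_reference, parse_file_reference

reference = build_file_reference(record_id="upload-123")  # id is illustrative
parsed = parse_file_reference(reference)
assert parsed is not None and parsed.record_id == "upload-123"

# Assumed failure mode: unrecognized strings parse to None, which callers
# above turn into "Missing file reference" / "missing reference" errors.
assert parse_file_reference("not-a-reference") is None
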
@@ -1,10 +1,5 @@
from enum import StrEnum, auto
"""Compatibility wrapper for the runtime embedding input enum."""

from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType

class EmbeddingInputType(StrEnum):
    """
    Enum for embedding input type.
    """

    DOCUMENT = auto()
    QUERY = auto()
__all__ = ["EmbeddingInputType"]

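This hunk replaces the local enum with a re-export, so old import paths keep working. A minimal sketch of why that is import-compatible:

# Sketch: after the change both paths resolve to the same class object.
from core.entities.embedding_type import EmbeddingInputType as LegacyPath
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType as RuntimePath

assert LegacyPath is RuntimePath  # the wrapper module just re-exports
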
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import logging
import re
@@ -5,7 +7,7 @@ from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session

@@ -19,6 +21,7 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from dify_graph.model_runtime.entities.provider_entities import (
    ConfigurateMethod,
@@ -28,6 +31,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
)
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from dify_graph.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel):
    - Load balancing configurations
    - Model enablement/disablement

    Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so
    nested schema and model lookups reuse the caller scope that was already
    resolved by the composition layer.

    TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
    """

@@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel):

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())
    _bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None)

    @model_validator(mode="after")
    def _(self):
@@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel):
        self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
        return self

    def bind_model_runtime(self, model_runtime: ModelRuntime) -> None:
        """Attach the already-composed runtime for request-bound call chains."""
        self._bound_model_runtime = model_runtime

    def get_model_provider_factory(self) -> ModelProviderFactory:
        """Return a provider factory that preserves any request-bound runtime."""
        if self._bound_model_runtime is not None:
            return ModelProviderFactory(model_runtime=self._bound_model_runtime)
        return create_plugin_model_provider_factory(tenant_id=self.tenant_id)

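The docstring above describes the intended flow; here is a hedged usage sketch (the `config` and `runtime` objects are assumed to come from the composition layer):

config: ProviderConfiguration = ...   # resolved elsewhere for the request
runtime: ModelRuntime = ...           # composed for the current caller scope

config.bind_model_runtime(runtime)
factory = config.get_model_provider_factory()  # reuses the bound runtime
# Without bind_model_runtime(), the same call falls back to
# create_plugin_model_provider_factory(tenant_id=config.tenant_id).
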
    def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
        """
        Get current credentials.
@@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel):
                        tenant_id=self.tenant_id, token=original_credentials[key]
                    )

        model_provider_factory = ModelProviderFactory(self.tenant_id)
        model_provider_factory = self.get_model_provider_factory()
        validated_credentials = model_provider_factory.provider_credentials_validate(
            provider=self.provider.provider, credentials=credentials
        )
@@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel):
                        tenant_id=self.tenant_id, token=original_credentials[key]
                    )

        model_provider_factory = ModelProviderFactory(self.tenant_id)
        model_provider_factory = self.get_model_provider_factory()
        validated_credentials = model_provider_factory.model_credentials_validate(
            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
        )
@@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel):
        :param model_type: model type
        :return:
        """
        model_provider_factory = ModelProviderFactory(self.tenant_id)
        model_provider_factory = self.get_model_provider_factory()

        # Get model instance of LLM
        return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
@@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel):
        """
        Get model schema
        """
        model_provider_factory = ModelProviderFactory(self.tenant_id)
        model_provider_factory = self.get_model_provider_factory()
        return model_provider_factory.get_model_schema(
            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
        )
@@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel):
        :param model: model name
        :return:
        """
        model_provider_factory = ModelProviderFactory(self.tenant_id)
        model_provider_factory = self.get_model_provider_factory()
        provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)

        model_types: list[ModelType] = []

@@ -4,10 +4,10 @@ from typing import cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_hosting_provider import hosting_configuration
from models.provider import ProviderType

@@ -41,7 +41,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
    text_chunk = secrets.choice(text_chunks)

    try:
        model_provider_factory = ModelProviderFactory(tenant_id)
        model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)

        # Get model instance of LLM
        model_type_instance = model_provider_factory.get_model_type_instance(

@@ -50,7 +50,10 @@ logger = logging.getLogger(__name__)
class IndexingRunner:
    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    @staticmethod
    def _get_model_manager(tenant_id: str) -> ModelManager:
        return ModelManager.for_tenant(tenant_id=tenant_id)

    def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
        """Handle indexing errors by updating document status."""
@@ -291,20 +294,20 @@ class IndexingRunner:
            raise ValueError("Dataset not found.")
        if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance(
                    tenant_id=tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )
        else:
            if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )
@@ -574,7 +577,7 @@ class IndexingRunner:

        embedding_model_instance = None
        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
            embedding_model_instance = self.model_manager.get_model_instance(
            embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
@@ -766,14 +769,14 @@ class IndexingRunner:
        embedding_model_instance = None
        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

@@ -62,7 +62,7 @@ class LLMGenerator:

        prompt += query + "\n"

        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
@@ -120,7 +120,7 @@ class LLMGenerator:
        prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions})

        try:
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
            model_instance = model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM,
@@ -172,7 +172,7 @@ class LLMGenerator:

        prompt_messages = [UserPromptMessage(content=prompt_generate)]

        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)

        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
@@ -219,7 +219,7 @@ class LLMGenerator:
        prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)]

        # get model instance
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
@@ -306,7 +306,7 @@ class LLMGenerator:
            remove_template_variables=False,
        )

        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
@@ -337,7 +337,7 @@ class LLMGenerator:
    def generate_qa_document(cls, tenant_id: str, query, document_language: str):
        prompt = GENERATOR_QA_PROMPT.format(language=document_language)

        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
@@ -362,7 +362,7 @@ class LLMGenerator:

    @classmethod
    def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
@@ -536,7 +536,7 @@ class LLMGenerator:
        injected_instruction = injected_instruction.replace(CURRENT, current or "null")
        if ERROR_MESSAGE in injected_instruction:
            injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
        model_instance = ModelManager().get_model_instance(
        model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.provider,

@@ -55,7 +55,6 @@ def invoke_llm_with_structured_output(
    tools: Sequence[PromptMessageTool] | None = None,
    stop: list[str] | None = None,
    stream: Literal[True],
    user: str | None = None,
    callbacks: list[Callback] | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
@@ -70,7 +69,6 @@ def invoke_llm_with_structured_output(
    tools: Sequence[PromptMessageTool] | None = None,
    stop: list[str] | None = None,
    stream: Literal[False],
    user: str | None = None,
    callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
@@ -85,7 +83,6 @@ def invoke_llm_with_structured_output(
    tools: Sequence[PromptMessageTool] | None = None,
    stop: list[str] | None = None,
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
@@ -99,7 +96,6 @@ def invoke_llm_with_structured_output(
    tools: Sequence[PromptMessageTool] | None = None,
    stop: list[str] | None = None,
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
    """
@@ -113,7 +109,6 @@ def invoke_llm_with_structured_output(
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :param callbacks: callbacks
    :return: full response or stream response chunk generator result
    """
@@ -143,7 +138,6 @@ def invoke_llm_with_structured_output(
        tools=tools,
        stop=stop,
        stream=stream,
        user=user,
        callbacks=callbacks,
    )


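The overload trio above uses `Literal` stream flags so type checkers pick the right return type per call site. A minimal self-contained sketch of that pattern (simplified result types, not the real ones):

from collections.abc import Generator
from typing import Literal, overload

@overload
def invoke(*, stream: Literal[True]) -> Generator[str, None, None]: ...
@overload
def invoke(*, stream: Literal[False]) -> str: ...
@overload
def invoke(*, stream: bool = True) -> str | Generator[str, None, None]: ...
def invoke(*, stream: bool = True) -> str | Generator[str, None, None]:
    if stream:
        return (chunk for chunk in ("a", "b"))  # chunked result
    return "ab"                                  # full result
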
@@ -4,6 +4,7 @@ from sqlalchemy import select
from sqlalchemy.orm import sessionmaker

from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.file_access import DatabaseFileAccessController
from core.model_manager import ModelInstance
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from dify_graph.file import file_manager
@@ -23,6 +24,8 @@ from models.workflow import Workflow
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory

_file_access_controller = DatabaseFileAccessController()


class TokenBufferMemory:
    def __init__(
@@ -85,7 +88,10 @@ class TokenBufferMemory:
            # Build files directly without filtering by belongs_to
            file_objs = [
                file_factory.build_from_message_file(
                    message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config
                    message_file=message_file,
                    tenant_id=app_record.tenant_id,
                    config=file_extra_config,
                    access_controller=_file_access_controller,
                )
                for message_file in message_files
            ]

@@ -7,11 +7,12 @@ from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import LLMResult
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
@@ -30,7 +31,7 @@ logger = logging.getLogger(__name__)

class ModelInstance:
    """
    Model instance class
    Model instance class.
    """

    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
@@ -49,6 +50,13 @@ class ModelInstance:
            credentials=self.credentials,
        )

    def get_model_schema(self) -> AIModelEntity:
        """Return the resolved schema for the current model instance."""
        model_schema = self.model_type_instance.get_model_schema(self.model_name, self.credentials)
        if model_schema is None:
            raise ValueError(f"model schema not found for {self.model_name}")
        return model_schema

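A hedged usage note for the new accessor (provider and model names below are placeholders):

# Sketch: callers can now ask the instance for its own schema instead of going
# through the provider factory; a missing schema raises rather than returning None.
model_instance = ModelManager.for_tenant(tenant_id="tenant-123").get_model_instance(
    tenant_id="tenant-123", provider="openai", model_type=ModelType.LLM, model="gpt-4o"
)
schema = model_instance.get_model_schema()  # AIModelEntity, never None
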
    @staticmethod
    def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str):
        """
@@ -110,7 +118,6 @@ class ModelInstance:
        tools: Sequence[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: Literal[True] = True,
        user: str | None = None,
        callbacks: list[Callback] | None = None,
    ) -> Generator: ...

@@ -122,7 +129,6 @@ class ModelInstance:
        tools: Sequence[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: Literal[False] = False,
        user: str | None = None,
        callbacks: list[Callback] | None = None,
    ) -> LLMResult: ...

@@ -134,7 +140,6 @@ class ModelInstance:
        tools: Sequence[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
        callbacks: list[Callback] | None = None,
    ) -> Union[LLMResult, Generator]: ...

@@ -145,7 +150,6 @@ class ModelInstance:
        tools: Sequence[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
        user: str | None = None,
        callbacks: list[Callback] | None = None,
    ) -> Union[LLMResult, Generator]:
        """
@@ -156,7 +160,6 @@ class ModelInstance:
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        """
@@ -173,7 +176,6 @@ class ModelInstance:
                tools=tools,
                stop=stop,
                stream=stream,
                user=user,
                callbacks=callbacks,
            ),
        )
@@ -202,13 +204,12 @@ class ModelInstance:
        )

    def invoke_text_embedding(
        self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
        self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
    ) -> EmbeddingResult:
        """
        Invoke large language model

        :param texts: texts to embed
        :param user: unique user id
        :param input_type: input type
        :return: embeddings result
        """
@@ -221,7 +222,6 @@ class ModelInstance:
                model=self.model_name,
                credentials=self.credentials,
                texts=texts,
                user=user,
                input_type=input_type,
            ),
        )
@@ -229,14 +229,12 @@ class ModelInstance:
    def invoke_multimodal_embedding(
        self,
        multimodel_documents: list[dict],
        user: str | None = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> EmbeddingResult:
        """
        Invoke large language model

        :param multimodel_documents: multimodel documents to embed
        :param user: unique user id
        :param input_type: input type
        :return: embeddings result
        """
@@ -249,7 +247,6 @@ class ModelInstance:
                model=self.model_name,
                credentials=self.credentials,
                multimodel_documents=multimodel_documents,
                user=user,
                input_type=input_type,
            ),
        )
@@ -279,7 +276,6 @@ class ModelInstance:
        docs: list[str],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
    ) -> RerankResult:
        """
        Invoke rerank model
@@ -288,7 +284,6 @@ class ModelInstance:
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        if not isinstance(self.model_type_instance, RerankModel):
@@ -303,7 +298,6 @@ class ModelInstance:
                docs=docs,
                score_threshold=score_threshold,
                top_n=top_n,
                user=user,
            ),
        )

@@ -313,7 +307,6 @@ class ModelInstance:
        docs: list[dict],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
    ) -> RerankResult:
        """
        Invoke rerank model
@@ -322,7 +315,6 @@ class ModelInstance:
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        if not isinstance(self.model_type_instance, RerankModel):
@@ -337,16 +329,14 @@ class ModelInstance:
                docs=docs,
                score_threshold=score_threshold,
                top_n=top_n,
                user=user,
            ),
        )

    def invoke_moderation(self, text: str, user: str | None = None) -> bool:
    def invoke_moderation(self, text: str) -> bool:
        """
        Invoke moderation model

        :param text: text to moderate
        :param user: unique user id
        :return: false if text is safe, true otherwise
        """
        if not isinstance(self.model_type_instance, ModerationModel):
@@ -358,16 +348,14 @@ class ModelInstance:
            model=self.model_name,
            credentials=self.credentials,
            text=text,
            user=user,
            ),
        )

    def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str:
    def invoke_speech2text(self, file: IO[bytes]) -> str:
        """
        Invoke large language model

        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        if not isinstance(self.model_type_instance, Speech2TextModel):
@@ -379,18 +367,15 @@ class ModelInstance:
            model=self.model_name,
            credentials=self.credentials,
            file=file,
            user=user,
            ),
        )

    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]:
    def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
        """
        Invoke large language tts model

        :param content_text: text content to be translated
        :param tenant_id: user tenant id
        :param voice: model timbre
        :param user: unique user id
        :return: text for given audio file
        """
        if not isinstance(self.model_type_instance, TTSModel):
@@ -402,8 +387,6 @@ class ModelInstance:
            model=self.model_name,
            credentials=self.credentials,
            content_text=content_text,
            user=user,
            tenant_id=tenant_id,
            voice=voice,
            ),
        )
@@ -477,10 +460,20 @@ class ModelInstance:


class ModelManager:
    def __init__(self):
        self._provider_manager = ProviderManager()
    def __init__(self, provider_manager: ProviderManager):
        self._provider_manager = provider_manager

    def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
    @classmethod
    def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
        return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id))

    def get_model_instance(
        self,
        tenant_id: str,
        provider: str,
        model_type: ModelType,
        model: str,
    ) -> ModelInstance:
        """
        Get model instance
        :param tenant_id: tenant id
@@ -496,7 +489,8 @@ class ModelManager:
            tenant_id=tenant_id, provider=provider, model_type=model_type
        )

        return ModelInstance(provider_model_bundle, model)
        model_instance = ModelInstance(provider_model_bundle, model)
        return model_instance

    def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
        """

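Since `ModelManager()` no longer self-constructs its `ProviderManager`, callers migrate as below (a sketch; tenant and user ids are placeholders):

# Old: manager = ModelManager()                  # implicit ProviderManager
# New: scope the manager to the acting tenant/user up front.
manager = ModelManager.for_tenant(tenant_id="tenant-123", user_id="user-456")
instance = manager.get_default_model_instance(
    tenant_id="tenant-123",
    model_type=ModelType.LLM,
)
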
@@ -50,7 +50,7 @@ class OpenAIModeration(Moderation):

    def _is_violated(self, inputs: dict):
        text = "\n".join(str(inputs.values()))
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
        )

@@ -296,7 +296,9 @@ class AliyunDataTrace(BaseTraceInstance):
                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
            )

        return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
        return workflow_node_execution_repository.get_by_workflow_execution(
            workflow_execution_id=trace_info.workflow_run_id
        )

    def build_workflow_node_span(
        self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
@@ -271,8 +271,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        )

        # Get all executions for this workflow run
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
            workflow_run_id=trace_info.workflow_run_id
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
            workflow_execution_id=trace_info.workflow_run_id
        )

        try:

@@ -130,8 +130,8 @@ class LangFuseDataTrace(BaseTraceInstance):
        )

        # Get all executions for this workflow run
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
            workflow_run_id=trace_info.workflow_run_id
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
            workflow_execution_id=trace_info.workflow_run_id
        )

        for node_execution in workflow_node_executions:

@@ -152,8 +152,8 @@ class LangSmithDataTrace(BaseTraceInstance):
        )

        # Get all executions for this workflow run
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
            workflow_run_id=trace_info.workflow_run_id
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
            workflow_execution_id=trace_info.workflow_run_id
        )

        for node_execution in workflow_node_executions:

@@ -176,8 +176,8 @@ class OpikDataTrace(BaseTraceInstance):
        )

        # Get all executions for this workflow run
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
            workflow_run_id=trace_info.workflow_run_id
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
            workflow_execution_id=trace_info.workflow_run_id
        )

        for node_execution in workflow_node_executions:

@@ -256,7 +256,7 @@ class TencentDataTrace(BaseTraceInstance):
                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
            )

            executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
            executions = repository.get_by_workflow_execution(workflow_execution_id=trace_info.workflow_run_id)
            return list(executions)

        except Exception:

@@ -161,8 +161,8 @@ class WeaveDataTrace(BaseTraceInstance):
        )

        # Get all executions for this workflow run
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
            workflow_run_id=trace_info.workflow_run_id
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
            workflow_execution_id=trace_info.workflow_run_id
        )

        # rearrange workflow_node_executions by starting time

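Every tracer above migrates the same mechanical rename; a one-line sketch (the repository object and run id are elided):

# Old: repo.get_by_workflow_run(workflow_run_id=run_id)
# New: the repository now speaks in workflow executions.
executions = repo.get_by_workflow_execution(workflow_execution_id=run_id)
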
@@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import (
    SystemPromptMessage,
    UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Tenant


class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
    @staticmethod
    def _get_bound_model_instance(
        *,
        tenant_id: str,
        user_id: str | None,
        provider: str,
        model_type: ModelType,
        model: str,
    ):
        return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
            tenant_id=tenant_id,
            provider=provider,
            model_type=model_type,
            model=model,
        )

    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
@@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
        """
        invoke llm
        """
        model_instance = ModelManager().get_model_instance(
        model_instance = cls._get_bound_model_instance(
            tenant_id=tenant.id,
            user_id=user_id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
@@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
        )

        if isinstance(response, Generator):

@@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
        """
        invoke llm with structured output
        """
        model_instance = ModelManager().get_model_instance(
        model_instance = cls._get_bound_model_instance(
            tenant_id=tenant.id,
            user_id=user_id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
@@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
            model_parameters=payload.completion_params,
        )

@@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
        """
        invoke text embedding
        """
        model_instance = ModelManager().get_model_instance(
        model_instance = cls._get_bound_model_instance(
            tenant_id=tenant.id,
            user_id=user_id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )
        response = model_instance.invoke_text_embedding(texts=payload.texts)

        return response

@@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
        """
        invoke rerank
        """
        model_instance = ModelManager().get_model_instance(
        model_instance = cls._get_bound_model_instance(
            tenant_id=tenant.id,
            user_id=user_id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
@@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response
@@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
        """
        invoke tts
        """
        model_instance = ModelManager().get_model_instance(
        model_instance = cls._get_bound_model_instance(
            tenant_id=tenant.id,
            user_id=user_id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )
        response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)

        def handle() -> Generator[dict, None, None]:
            for chunk in response:

@@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
        """
        invoke speech2text
        """
        model_instance = ModelManager().get_model_instance(
        model_instance = cls._get_bound_model_instance(
            tenant_id=tenant.id,
            user_id=user_id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
@@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
            temp.flush()
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )
            response = model_instance.invoke_speech2text(file=temp)

            return {
                "result": response,
@@ -252,36 +261,38 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
        """
        invoke moderation
        """
        model_instance = ModelManager().get_model_instance(
        model_instance = cls._get_bound_model_instance(
            tenant_id=tenant.id,
            user_id=user_id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )
        response = model_instance.invoke_moderation(text=payload.text)

        return {
            "result": response,
        }

    @classmethod
    def get_system_model_max_tokens(cls, tenant_id: str) -> int:
    def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int:
        """
        get system model max tokens
        """
        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id)

    @classmethod
    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int:
        """
        get prompt tokens
        """
        return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
        return ModelInvocationUtils.calculate_tokens(
            tenant_id=tenant_id,
            prompt_messages=prompt_messages,
            user_id=user_id,
        )

    @classmethod
    def invoke_system_model(
@@ -299,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
            tool_type=ToolProviderType.PLUGIN,
            tool_name="plugin",
            prompt_messages=prompt_messages,
            caller_user_id=user_id,
        )

    @classmethod
@@ -306,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
        """
        invoke summary
        """
        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id)
        content = payload.text

        SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
@@ -325,6 +337,7 @@ Here is the extra instruction you need to follow:
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=content)],
                user_id=user_id,
            )
            < max_tokens * 0.6
        ):
@@ -337,6 +350,7 @@ Here is the extra instruction you need to follow:
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
                user_id=user_id,
            )

        def summarize(content: str) -> str:
@@ -394,6 +408,7 @@ Here is the extra instruction you need to follow:
                cls.get_prompt_tokens(
                    tenant_id=tenant.id,
                    prompt_messages=[UserPromptMessage(content=result)],
                    user_id=user_id,
                )
                > max_tokens * 0.7
            ):

@@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
        # get tool runtime
        try:
            tool_runtime = ToolManager.get_tool_runtime_from_plugin(
                tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
                tool_type,
                tenant_id,
                provider,
                tool_name,
                tool_parameters,
                user_id=user_id,
                credential_id=credential_id,
            )
            response = ToolEngine.generic_invoke(
                tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1

@@ -1,6 +1,6 @@
import binascii
from collections.abc import Generator, Sequence
from typing import IO
from typing import IO, Any

from core.plugin.entities.plugin_daemon import (
    PluginBasicBooleanResponse,
@@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.utils.encoders import jsonable_encoder


class PluginModelClient(BasePluginClient):
    @staticmethod
    def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
        payload: dict[str, Any] = {"data": data}
        if user_id is not None:
            payload["user_id"] = user_id
        return payload

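A quick sketch of what `_dispatch_payload` changes on the wire (the provider value is illustrative):

# With a user: the key is present, as before.
PluginModelClient._dispatch_payload(user_id="user-456", data={"provider": "openai"})
# -> {"data": {"provider": "openai"}, "user_id": "user-456"}

# Without a user (the new None case): the key is omitted entirely rather
# than being sent as null.
PluginModelClient._dispatch_payload(user_id=None, data={"provider": "openai"})
# -> {"data": {"provider": "openai"}}
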
    def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
        """
        Fetch model providers for the given tenant.
@@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient):
    def get_model_schema(
        self,
        tenant_id: str,
        user_id: str,
        user_id: str | None,
        plugin_id: str,
        provider: str,
        model_type: str,
@@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient):
            "POST",
            f"plugin/{tenant_id}/dispatch/model/schema",
            PluginModelSchemaEntity,
            data={
                "user_id": user_id,
                "data": {
            data=self._dispatch_payload(
                user_id=user_id,
                data={
                    "provider": provider,
                    "model_type": model_type,
                    "model": model,
                    "credentials": credentials,
                },
            },
            ),
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
@@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient):
        return None

    def validate_provider_credentials(
        self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
        self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
    ) -> bool:
        """
        validate the credentials of the provider
@@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient):
            "POST",
            f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
            PluginBasicBooleanResponse,
            data={
                "user_id": user_id,
                "data": {
            data=self._dispatch_payload(
                user_id=user_id,
                data={
                    "provider": provider,
                    "credentials": credentials,
                },
            },
            ),
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
@@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient):
    def validate_model_credentials(
        self,
        tenant_id: str,
        user_id: str,
        user_id: str | None,
        plugin_id: str,
        provider: str,
        model_type: str,
@@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient):
            "POST",
            f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
            PluginBasicBooleanResponse,
            data={
                "user_id": user_id,
                "data": {
            data=self._dispatch_payload(
                user_id=user_id,
                data={
                    "provider": provider,
                    "model_type": model_type,
                    "model": model,
                    "credentials": credentials,
                },
            },
            ),
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
@@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient):
    def invoke_llm(
        self,
        tenant_id: str,
        user_id: str,
        user_id: str | None,
        plugin_id: str,
        provider: str,
        model: str,
@@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient):
            path=f"plugin/{tenant_id}/dispatch/llm/invoke",
            type_=LLMResultChunk,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                self._dispatch_payload(
                    user_id=user_id,
                    data={
                        "provider": provider,
                        "model_type": "llm",
                        "model": model,
@@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient):
                        "stop": stop,
                        "stream": stream,
                    },
                }
            )
                ),
            headers={
                "X-Plugin-ID": plugin_id,
@@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient):
    def get_llm_num_tokens(
        self,
        tenant_id: str,
        user_id: str,
        user_id: str | None,
        plugin_id: str,
        provider: str,
        model_type: str,
@@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient):
            path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
            type_=PluginLLMNumTokensResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                self._dispatch_payload(
                    user_id=user_id,
                    data={
                        "provider": provider,
                        "model_type": model_type,
                        "model": model,
@@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient):
                        "prompt_messages": prompt_messages,
                        "tools": tools,
                    },
                }
            )
                ),
            headers={
                "X-Plugin-ID": plugin_id,
@@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient):
    def invoke_text_embedding(
        self,
        tenant_id: str,
        user_id: str,
        user_id: str | None,
        plugin_id: str,
        provider: str,
        model: str,
@@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient):
            path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
            type_=EmbeddingResult,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                self._dispatch_payload(
                    user_id=user_id,
                    data={
                        "provider": provider,
                        "model_type": "text-embedding",
                        "model": model,
@@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient):
                        "texts": texts,
                        "input_type": input_type,
                    },
                }
            )
                ),
            headers={
                "X-Plugin-ID": plugin_id,
@@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient):
    def invoke_multimodal_embedding(
        self,
        tenant_id: str,
        user_id: str,
        user_id: str | None,
        plugin_id: str,
        provider: str,
        model: str,
@@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient):
            path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
            type_=EmbeddingResult,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                self._dispatch_payload(
                    user_id=user_id,
                    data={
                        "provider": provider,
                        "model_type": "text-embedding",
                        "model": model,
@@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient):
                        "documents": documents,
                        "input_type": input_type,
                    },
                }
            )
                ),
            headers={
                "X-Plugin-ID": plugin_id,
@@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient):
    def get_text_embedding_num_tokens(
        self,
        tenant_id: str,
        user_id: str,
        user_id: str | None,
        plugin_id: str,
        provider: str,
        model: str,
@@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient):
            path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
            type_=PluginTextEmbeddingNumTokensResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                self._dispatch_payload(
                    user_id=user_id,
                    data={
                        "provider": provider,
                        "model_type": "text-embedding",
                        "model": model,
                        "credentials": credentials,
                        "texts": texts,
                    },
                }
            )
                ),
            headers={
                "X-Plugin-ID": plugin_id,
@@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient):
    def invoke_rerank(
        self,
        tenant_id: str,
        user_id: str,
        user_id: str | None,
        plugin_id: str,
        provider: str,
        model: str,
@@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient):
            path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
            type_=RerankResult,
            data=jsonable_encoder(
                {
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
@@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"score_threshold": score_threshold,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: dict,
|
||||
docs: list[dict],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
) -> RerankResult:
|
||||
@@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
|
||||
type_=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
@@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"score_threshold": score_threshold,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_tts(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
@@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"content_text": content_text,
|
||||
"voice": voice,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
|
||||
type_=PluginVoicesResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"language": language,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "speech2text",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"file": binascii.hexlify(file.read()).decode(),
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_moderation(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
|
||||
type_=PluginBasicBooleanResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "moderation",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
|
||||
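Every hunk above converges on the same refactor: the inline {"user_id": ..., "data": ...} envelope is replaced by a single _dispatch_payload call. The helper itself lives on BasePluginClient and is not shown in this diff, so the sketch below is only an inference from the call sites — the name and keyword-only signature come from the hunks, the body is an assumption:

def _dispatch_payload(self, *, user_id: str | None, data: dict) -> dict:
    # Assumed body: rebuild the former inline envelope in one place so every
    # dispatch method handles the now-optional caller id uniformly.
    return {"user_id": user_id, "data": data}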
499
api/core/plugin/impl/model_runtime.py
Normal file
@@ -0,0 +1,499 @@
from __future__ import annotations

import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union

from pydantic import ValidationError
from redis import RedisError

from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.runtime import ModelRuntime
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID

logger = logging.getLogger(__name__)

# `TS` means tenant scope
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"


class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""

tenant_id: str
user_id: str | None
client: PluginModelClient
_provider_entities: tuple[ProviderEntity, ...] | None
_provider_entities_lock: Lock

def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
if client is None:
raise ValueError("client is required.")
self.tenant_id = tenant_id
self.user_id = user_id
self.client = client
self._provider_entities = None
self._provider_entities_lock = Lock()

def fetch_model_providers(self) -> Sequence[ProviderEntity]:
if self._provider_entities is not None:
return self._provider_entities

with self._provider_entities_lock:
if self._provider_entities is None:
self._provider_entities = tuple(
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
)

return self._provider_entities

def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
provider_schema = self._get_provider_schema(provider)

if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_Hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_US
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")

if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")

image_mime_types = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"bmp": "image/bmp",
"tiff": "image/tiff",
"tif": "image/tiff",
"webp": "image/webp",
"svg": "image/svg+xml",
"ico": "image/vnd.microsoft.icon",
"heif": "image/heif",
"heic": "image/heic",
}

extension = file_name.split(".")[-1]
mime_type = image_mime_types.get(extension, "image/png")
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type

def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
credentials=credentials,
)

def validate_model_credentials(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_model_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)

def get_model_schema(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> AIModelEntity | None:
cache_key = self._get_schema_cache_key(
provider=provider,
model_type=model_type,
model=model,
credentials=credentials,
)

cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)

if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)

plugin_id, provider_name = self._split_provider(provider)
schema = self.client.get_model_schema(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)

if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)

return schema

def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_llm(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=list(prompt_messages),
tools=tools,
stop=list(stop) if stop else None,
stream=stream,
)

def get_llm_num_tokens(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
) -> int:
if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
return 0

plugin_id, provider_name = self._split_provider(provider)
return self.client.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)

def invoke_text_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)

def invoke_multimodal_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
documents: list[dict[str, Any]],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
documents=documents,
input_type=input_type,
)

def get_text_embedding_num_tokens(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
)

def invoke_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)

def invoke_multimodal_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)

def invoke_tts(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Iterable[bytes]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_tts(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)

def get_tts_model_voices(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
language: str | None,
) -> Any:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
language=language,
)

def invoke_speech_to_text(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
file=file,
)

def invoke_moderation(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
text: str,
) -> bool:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_moderation(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
text=text,
)

def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
"""
Expose a bare provider alias only for the canonical provider mapping.

Multiple plugins can publish the same short provider slug. If every
provider entity keeps that slug in ``provider_name``, callers that still
resolve by short name become order-dependent. Restrict the alias to the
provider selected by ``ModelProviderID`` so legacy short-name lookups
remain deterministic while the runtime surface stays canonical.
"""
try:
canonical_provider_id = ModelProviderID(provider.provider)
except ValueError:
return ""

if canonical_provider_id.plugin_id != provider.plugin_id:
return ""
if canonical_provider_id.provider_name != provider.provider:
return ""

return provider.provider

def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
declaration = provider.declaration.model_copy(deep=True)
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
declaration.provider_name = self._get_provider_short_name_alias(provider)
return declaration

def _get_provider_schema(self, provider: str) -> ProviderEntity:
providers = self.fetch_model_providers()
provider_entity = next((item for item in providers if item.provider == provider), None)
if provider_entity is None:
provider_entity = next((item for item in providers if provider == item.provider_name), None)
if provider_entity is None:
raise ValueError(f"Invalid provider: {provider}")
return provider_entity

def _get_schema_cache_key(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> str:
# The plugin daemon distinguishes ``None`` from an explicit empty-string
# caller id, so the cache must only collapse ``None`` into tenant scope.
cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}"
sorted_credentials = sorted(credentials.items()) if credentials else []
if not sorted_credentials:
return cache_key
hashed_credentials = ":".join(
[hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
)
return f"{cache_key}:{hashed_credentials}"

def _split_provider(self, provider: str) -> tuple[str, str]:
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name
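To make _get_schema_cache_key concrete, a small worked example; the tenant, provider, and credential values are made up, while the __DIFY_TS__ sentinel and the md5-per-pair scheme come directly from the code above:

import hashlib

tenant_id, provider, model = "tenant-1", "langgenius/openai/openai", "gpt-4o"
# user_id is None in this example, so the tenant-scope sentinel fills the caller slot
base = f"{tenant_id}:{provider}:llm:{model}:__DIFY_TS__"
# each credential pair is digested, so raw secrets never appear in the key itself
digest = hashlib.md5("api_key:sk-example".encode()).hexdigest()
print(f"{base}:{digest}")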
89
api/core/plugin/impl/model_runtime_factory.py
Normal file
@@ -0,0 +1,89 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from core.plugin.impl.model import PluginModelClient
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

if TYPE_CHECKING:
from core.model_manager import ModelManager
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager


class PluginModelAssembly:
"""Compose request-scoped model views on top of a single plugin runtime."""

tenant_id: str
user_id: str | None
_model_runtime: PluginModelRuntime | None
_model_provider_factory: ModelProviderFactory | None
_provider_manager: ProviderManager | None
_model_manager: ModelManager | None

def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None:
self.tenant_id = tenant_id
self.user_id = user_id
self._model_runtime = None
self._model_provider_factory = None
self._provider_manager = None
self._model_manager = None

@property
def model_runtime(self) -> PluginModelRuntime:
if self._model_runtime is None:
self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id)
return self._model_runtime

@property
def model_provider_factory(self) -> ModelProviderFactory:
if self._model_provider_factory is None:
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
return self._model_provider_factory

@property
def provider_manager(self) -> ProviderManager:
if self._provider_manager is None:
from core.provider_manager import ProviderManager

self._provider_manager = ProviderManager(model_runtime=self.model_runtime)
return self._provider_manager

@property
def model_manager(self) -> ModelManager:
if self._model_manager is None:
from core.model_manager import ModelManager

self._model_manager = ModelManager(provider_manager=self.provider_manager)
return self._model_manager


def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly:
"""Create a request-scoped assembly that shares one plugin runtime across model views."""
return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id)


def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime:
"""Create a plugin runtime with its client dependency fully composed."""
from core.plugin.impl.model_runtime import PluginModelRuntime

return PluginModelRuntime(
tenant_id=tenant_id,
user_id=user_id,
client=PluginModelClient(),
)


def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory:
"""Create a tenant-bound model provider factory for service flows."""
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory


def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
"""Create a tenant-bound provider manager for service flows."""
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager


def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
"""Create a tenant-bound model manager for service flows."""
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager
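A brief usage sketch of the new factory module; the tenant and user ids below are placeholders, everything else is the API defined above:

from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly

assembly = create_plugin_model_assembly(tenant_id="tenant-1", user_id="user-1")
model_manager = assembly.model_manager    # lazily wires a ProviderManager underneath
factory = assembly.model_provider_factory  # reuses the same PluginModelRuntime
pm1 = assembly.provider_manager
pm2 = assembly.provider_manager
assert pm1 is pm2  # each view is created once and cached per assembly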
@@ -1,50 +1,7 @@
from typing import Literal
from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig

from pydantic import BaseModel

from dify_graph.model_runtime.entities.message_entities import PromptMessageRole


class ChatModelMessage(BaseModel):
"""
Chat Message.
"""

text: str
role: PromptMessageRole
edition_type: Literal["basic", "jinja2"] | None = None


class CompletionModelPromptTemplate(BaseModel):
"""
Completion Model Prompt Template.
"""

text: str
edition_type: Literal["basic", "jinja2"] | None = None


class MemoryConfig(BaseModel):
"""
Memory Config.
"""

class RolePrefix(BaseModel):
"""
Role Prefix.
"""

user: str
assistant: str

class WindowConfig(BaseModel):
"""
Window Config.
"""

enabled: bool
size: int | None = None

role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
__all__ = [
"ChatModelMessage",
"CompletionModelPromptTemplate",
"MemoryConfig",
]
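The module body is reduced to a re-export, so existing import sites keep resolving to the canonical dify_graph classes. Assuming this file is api/core/prompt/entities/advanced_prompt_entities.py — which the import hunks further down suggest — the effect is:

# Both import paths now resolve to the same class object.
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.prompt_entities import MemoryConfig as CanonicalMemoryConfig

assert MemoryConfig is CanonicalMemoryConfig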
@@ -1,9 +1,11 @@
from __future__ import annotations

import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -53,15 +55,25 @@ from models.provider import (
from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService

if TYPE_CHECKING:
from dify_graph.model_runtime.runtime import ModelRuntime


class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
ProviderManager manages tenant-scoped model provider configuration.

The runtime adapter is injected by the composition layer so this class stays
focused on configuration assembly instead of constructing plugin runtimes.
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
"""

def __init__(self):
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime

def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@@ -127,7 +139,7 @@ class ProviderManager:
)

# Get all provider entities
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()

# Get All preferred provider types of the workspace
@@ -255,6 +267,7 @@ class ProviderManager:
custom_configuration=custom_configuration,
model_settings=model_settings,
)
provider_configuration.bind_model_runtime(self._model_runtime)

provider_configurations[str(provider_id_entity)] = provider_configuration

@@ -321,7 +334,7 @@ class ProviderManager:
if not default_model:
return None

model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)

return DefaultModelEntity(
@@ -392,7 +405,7 @@ class ProviderManager:
# create default model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
model_type=model_type.to_origin_model_type(),
provider_name=provider,
model_name=model,
)
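For context, a sketch of the composition-layer wiring the new ProviderManager constructor expects; the factory function is the one defined in model_runtime_factory.py above, while the ids are placeholders:

from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
from core.provider_manager import ProviderManager

# Tenant-scoped runtime (no caller bound); request-bound flows pass a user_id.
runtime = create_plugin_model_runtime(tenant_id="tenant-1", user_id=None)
provider_manager = ProviderManager(model_runtime=runtime)
configurations = provider_manager.get_configurations(tenant_id="tenant-1")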
@@ -52,11 +52,10 @@ class DataPostProcessor:
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type)
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, query_type)

if self.reorder_runner:
documents = self.reorder_runner.run(documents)
@@ -106,9 +105,9 @@ class DataPostProcessor:
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(
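ModelManager.for_tenant recurs in every remaining hunk, but its definition is not part of this diff; the sketch below is only a plausible shape inferred from the call sites and from the ModelManager(provider_manager=...) constructor used in model_runtime_factory.py, not the actual implementation:

@classmethod
def for_tenant(cls, *, tenant_id: str, user_id: str | None = None) -> "ModelManager":
    # Assumed body: bind a tenant- (and optionally caller-) scoped provider
    # manager so call sites no longer construct plugin runtimes themselves.
    provider_manager = create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)
    return cls(provider_manager=provider_manager)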
@@ -328,7 +328,7 @@ class RetrievalService:
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
if dataset.is_multimodal:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model["reranking_provider_name"],

@@ -303,7 +303,7 @@ class Vector:
redis_client.delete(collection_exist_cache_key)

def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)

embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,

@@ -73,7 +73,7 @@ class DatasetDocumentStore:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,

@@ -21,9 +21,8 @@ logger = logging.getLogger(__name__)


class CacheEmbedding(Embeddings):
def __init__(self, model_instance: ModelInstance, user: str | None = None):
def __init__(self, model_instance: ModelInstance):
self._model_instance = model_instance
self._user = user

def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
@@ -65,7 +64,7 @@ class CacheEmbedding(Embeddings):
batch_texts = embedding_queue_texts[i : i + max_chunks]

embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT
)

for vector in embedding_result.embeddings:
@@ -147,7 +146,6 @@ class CacheEmbedding(Embeddings):

embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)

@@ -202,7 +200,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
texts=[text], input_type=EmbeddingInputType.QUERY
)

embedding_results = embedding_result.embeddings[0]
@@ -245,7 +243,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY
)

embedding_results = embedding_result.embeddings[0]

@@ -8,11 +8,12 @@ from typing import Any, cast

logger = logging.getLogger(__name__)

from core.app.file_access import DatabaseFileAccessController
from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
@@ -27,6 +28,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor, Su
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file_reference import build_file_reference
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
@@ -48,6 +50,8 @@ from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService

_file_access_controller = DatabaseFileAccessController()


class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
@@ -410,7 +414,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# If default prompt doesn't have {language} placeholder, use it as-is
pass

provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)
@@ -555,6 +559,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
file_obj = build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
access_controller=_file_access_controller,
)
file_objects.append(file_obj)
except Exception as e:
@@ -604,11 +609,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
reference=build_file_reference(
record_id=str(upload_file.id),
),
size=upload_file.size,
storage_key=upload_file.key,
)

@@ -12,7 +12,6 @@ class BaseRerankRunner(ABC):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
@@ -21,7 +20,6 @@ class BaseRerankRunner(ABC):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
raise NotImplementedError

@@ -22,7 +22,6 @@ class RerankModelRunner(BaseRerankRunner):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
@@ -31,10 +30,11 @@ class RerankModelRunner(BaseRerankRunner):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
@@ -43,12 +43,12 @@ class RerankModelRunner(BaseRerankRunner):
)
if not is_support_vision:
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n)
else:
return documents
else:
rerank_result, unique_documents = self.fetch_multimodal_rerank(
query, documents, score_threshold, top_n, user, query_type
query, documents, score_threshold, top_n, query_type
)

rerank_documents = []
@@ -73,7 +73,6 @@ class RerankModelRunner(BaseRerankRunner):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch text rerank
@@ -81,7 +80,6 @@ class RerankModelRunner(BaseRerankRunner):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
docs = []
@@ -103,7 +101,7 @@ class RerankModelRunner(BaseRerankRunner):
unique_documents.append(document)

rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents

@@ -113,7 +111,6 @@ class RerankModelRunner(BaseRerankRunner):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> tuple[RerankResult, list[Document]]:
"""
@@ -122,7 +119,6 @@ class RerankModelRunner(BaseRerankRunner):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:param query_type: query type
:return: rerank result
"""
@@ -168,7 +164,7 @@ class RerankModelRunner(BaseRerankRunner):

documents = unique_documents
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n)
return rerank_result, unique_documents
elif query_type == QueryType.IMAGE_QUERY:
# Query file info within db.session context to ensure thread-safe access
@@ -181,7 +177,7 @@ class RerankModelRunner(BaseRerankRunner):
"content_type": DocType.IMAGE,
}
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
else:

@@ -25,7 +25,6 @@ class WeightRerankRunner(BaseRerankRunner):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
@@ -34,7 +33,6 @@ class WeightRerankRunner(BaseRerankRunner):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed

:return:
"""
@@ -163,7 +161,7 @@ class WeightRerankRunner(BaseRerankRunner):
"""
query_vector_scores = []

model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)

embedding_model = model_manager.get_model_instance(
tenant_id=tenant_id,

@@ -56,6 +56,7 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.file_reference import build_file_reference
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.nodes.knowledge_retrieval.retrieval import (
KnowledgeRetrievalRequest,
@@ -160,7 +161,7 @@ class DatasetRetrieval:
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")

model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id)
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
@@ -383,23 +384,27 @@ class DatasetRetrieval:
return None, []
retrieve_config = config.retrieve_config

# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
)
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)

# get model schema
# Reuse the caller-bound model instance for both schema resolution and
# downstream planner/invoke calls so a single request never mixes
# tenant-scope and request-bound runtimes.
model_schema = model_type_instance.get_model_schema(
model=model_config.model, credentials=model_config.credentials
model=model_instance.model_name,
credentials=model_instance.credentials,
)

if not model_schema:
return None, []

model_config.provider_model_bundle = model_instance.provider_model_bundle
model_config.credentials = model_instance.credentials
model_config.model_schema = model_schema

planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
@@ -517,11 +522,12 @@ class DatasetRetrieval:
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=segment.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
reference=build_file_reference(
record_id=str(upload_file.id),
),
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
@@ -986,6 +992,24 @@ class DatasetRetrieval:
)
)

@staticmethod
def _resolve_creator_user_role(user_from: str) -> CreatorUserRole | None:
"""Map runtime user source values to dataset query audit roles.

Workflow run context uses the hyphenated ``end-user`` value, while
``DatasetQuery.created_by_role`` persists the underscore-based
``CreatorUserRole.END_USER`` enum. Query logging is a side effect, so an
unsupported value should be skipped instead of aborting retrieval.
"""
normalized_user_from = str(user_from).strip().lower().replace("-", "_")
if normalized_user_from == CreatorUserRole.ACCOUNT.value:
return CreatorUserRole.ACCOUNT
if normalized_user_from == CreatorUserRole.END_USER.value:
return CreatorUserRole.END_USER

logger.warning("Skipping dataset query audit log for unsupported user_from=%r", user_from)
return None
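Illustrative behaviour of the new normalization, assuming the enum values are the lowercase strings "account" and "end_user" (the values themselves are not shown in this diff):

assert DatasetRetrieval._resolve_creator_user_role("end-user") is CreatorUserRole.END_USER
assert DatasetRetrieval._resolve_creator_user_role("ACCOUNT") is CreatorUserRole.ACCOUNT
assert DatasetRetrieval._resolve_creator_user_role("bot") is None  # warned and skipped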
def _on_query(
|
||||
self,
|
||||
query: str | None,
|
||||
@@ -996,10 +1020,13 @@ class DatasetRetrieval:
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Handle query.
|
||||
Persist dataset query audit rows for retrieval requests.
|
||||
"""
|
||||
if not query and not attachment_ids:
|
||||
return
|
||||
created_by_role = self._resolve_creator_user_role(user_from)
|
||||
if created_by_role is None:
|
||||
return
|
||||
dataset_queries = []
|
||||
for dataset_id in dataset_ids:
|
||||
contents = []
|
||||
@@ -1014,7 +1041,7 @@ class DatasetRetrieval:
|
||||
content=json.dumps(contents),
|
||||
source=DatasetQuerySource.APP,
|
||||
source_app_id=app_id,
|
||||
created_by_role=CreatorUserRole(user_from),
|
||||
created_by_role=created_by_role,
|
||||
created_by=user_id,
|
||||
)
|
||||
dataset_queries.append(dataset_query)
|
||||
@@ -1411,7 +1438,7 @@ class DatasetRetrieval:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
|
||||
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id)
|
||||
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._get_prompt_template(
|
||||
@@ -1430,7 +1457,6 @@ class DatasetRetrieval:
|
||||
model_parameters=model_config.parameters,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1533,7 +1559,7 @@ class DatasetRetrieval:
|
||||
return filters
|
||||
|
||||
def _fetch_model_config(
|
||||
self, tenant_id: str, model: ModelConfig
|
||||
self, tenant_id: str, model: ModelConfig, user_id: str | None = None
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
@@ -1543,7 +1569,7 @@ class DatasetRetrieval:
|
||||
model_name = model.name
|
||||
provider_name = model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
@@ -3,13 +3,14 @@ from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
@@ -119,6 +120,7 @@ class ReactMultiDatasetRouter:
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
result_text, usage = self._invoke_llm(
|
||||
completion_param=model_config.parameters,
|
||||
@@ -150,19 +152,24 @@ class ReactMultiDatasetRouter:
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
|
||||
bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=completion_param,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
|
||||
@@ -4,7 +4,13 @@ from __future__ import annotations

from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
from .factory import (
    DifyCoreRepositoryFactory,
    OrderConfig,
    RepositoryImportError,
    WorkflowExecutionRepository,
    WorkflowNodeExecutionRepository,
)
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository

@@ -12,7 +18,10 @@ __all__ = [
    "CeleryWorkflowExecutionRepository",
    "CeleryWorkflowNodeExecutionRepository",
    "DifyCoreRepositoryFactory",
    "OrderConfig",
    "RepositoryImportError",
    "SQLAlchemyWorkflowExecutionRepository",
    "SQLAlchemyWorkflowNodeExecutionRepository",
    "WorkflowExecutionRepository",
    "WorkflowNodeExecutionRepository",
]
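With the widened re-export above, downstream code can take the repository abstractions from core.repositories instead of reaching into dify_graph.repositories. A sketch of the intended import site (resolves inside the Dify API code base, not standalone):

from core.repositories import (
    OrderConfig,
    WorkflowExecutionRepository,
    WorkflowNodeExecutionRepository,
)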
@@ -11,8 +11,8 @@ from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.repositories.factory import WorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom
@@ -12,11 +12,11 @@ from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution
from dify_graph.repositories.workflow_node_execution_repository import (
from core.repositories.factory import (
    OrderConfig,
    WorkflowNodeExecutionRepository,
)
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom

@@ -148,24 +148,24 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
    # For now, we'll re-raise the exception
    raise

def get_by_workflow_run(
def get_by_workflow_execution(
    self,
    workflow_run_id: str,
    workflow_execution_id: str,
    order_config: OrderConfig | None = None,
) -> Sequence[WorkflowNodeExecution]:
    """
    Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache.
    Retrieve all workflow node executions for a workflow execution from cache.

    Args:
        workflow_run_id: The workflow run ID
        workflow_execution_id: The workflow execution identifier
        order_config: Optional configuration for ordering results

    Returns:
        A sequence of WorkflowNodeExecution instances
    """
    try:
        # Get execution IDs for this workflow run from cache
        execution_ids = self._workflow_execution_mapping.get(workflow_run_id, [])
        # Get execution IDs for this workflow execution from cache
        execution_ids = self._workflow_execution_mapping.get(workflow_execution_id, [])

        # Retrieve executions from cache
        result = []

@@ -182,9 +182,16 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
for field_name in reversed(order_config.order_by):
    result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)

logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id)
logger.debug(
    "Retrieved %d workflow node executions for execution %s from cache",
    len(result),
    workflow_execution_id,
)
return result

except Exception:
logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id)
logger.exception(
    "Failed to get workflow node executions for execution %s from cache",
    workflow_execution_id,
)
return []
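A hedged usage sketch of the renamed method. Here repo stands for any object implementing the node-execution repository protocol, and the "index" sort key is an assumed sortable attribute, used for illustration only:

from core.repositories import OrderConfig


def load_node_executions(repo, workflow_execution_id: str):
    # Fetch node executions for one workflow execution, ordered ascending.
    return repo.get_by_workflow_execution(
        workflow_execution_id=workflow_execution_id,
        order_config=OrderConfig(order_by=["index"], order_direction="asc"),
    )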
@@ -5,20 +5,45 @@ This module provides a Django-like settings system for repository implementation
allowing users to configure different repository backends through string paths.
"""

from typing import Union
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Protocol, Union

from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from configs import dify_config
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution
from libs.module_loading import import_string
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom


@dataclass
class OrderConfig:
    """Configuration for ordering node execution instances."""

    order_by: list[str]
    order_direction: Literal["asc", "desc"] | None = None


class WorkflowExecutionRepository(Protocol):
    def save(self, execution: WorkflowExecution): ...


class WorkflowNodeExecutionRepository(Protocol):
    def save(self, execution: WorkflowNodeExecution): ...

    def save_execution_data(self, execution: WorkflowNodeExecution): ...

    def get_by_workflow_execution(
        self,
        workflow_execution_id: str,
        order_config: OrderConfig | None = None,
    ) -> Sequence[WorkflowNodeExecution]: ...


class RepositoryImportError(Exception):
    """Raised when a repository implementation cannot be imported or instantiated."""
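Because the repository contracts are now typing.Protocol classes defined next to the factory, any structurally matching class satisfies them without inheriting from dify_graph. A minimal in-memory sketch (the class and its storage layout are invented for illustration, not part of the commit):

from collections.abc import Sequence

from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository
from dify_graph.entities import WorkflowNodeExecution


class InMemoryNodeExecutionRepository:
    """Structurally satisfies the WorkflowNodeExecutionRepository protocol."""

    def __init__(self) -> None:
        self._by_execution: dict[str, list[WorkflowNodeExecution]] = {}

    def save(self, execution: WorkflowNodeExecution): ...

    def save_execution_data(self, execution: WorkflowNodeExecution): ...

    def get_by_workflow_execution(
        self,
        workflow_execution_id: str,
        order_config: OrderConfig | None = None,
    ) -> Sequence[WorkflowNodeExecution]:
        rows = list(self._by_execution.get(workflow_execution_id, []))
        if order_config is not None:
            # Same ordering convention as the Celery repository above.
            reverse = order_config.order_direction == "desc"
            for field_name in reversed(order_config.order_by):
                rows.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)
        return rows


repo: WorkflowNodeExecutionRepository = InMemoryNodeExecutionRepository()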
@@ -2,33 +2,23 @@ import dataclasses
import json
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from typing import Any, Protocol

from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload

from core.db.session_factory import session_factory
from dify_graph.nodes.human_input.entities import (
from core.workflow.human_input_compat import (
    BoundRecipient,
    DeliveryChannelConfig,
    EmailDeliveryMethod,
    EmailRecipients,
    ExternalRecipient,
    FormDefinition,
    HumanInputNodeData,
    MemberRecipient,
    WebAppDeliveryMethod,
)
from dify_graph.nodes.human_input.enums import (
    DeliveryMethodType,
    HumanInputFormKind,
    HumanInputFormStatus,
)
from dify_graph.repositories.human_input_form_repository import (
    FormCreateParams,
    FormNotFoundError,
    HumanInputFormEntity,
    HumanInputFormRecipientEntity,
    InteractiveSurfaceDeliveryMethod,
    is_human_input_webapp_enabled,
)
from dify_graph.nodes.human_input.entities import FormDefinition, HumanInputNodeData
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.account import Account, TenantAccountJoin

@@ -36,6 +26,7 @@ from models.human_input import (
    BackstageRecipientPayload,
    ConsoleDeliveryPayload,
    ConsoleRecipientPayload,
    DeliveryMethodType,
    EmailExternalRecipientPayload,
    EmailMemberRecipientPayload,
    HumanInputDelivery,

@@ -58,6 +49,65 @@ class _WorkspaceMemberInfo:
email: str


class FormNotFoundError(Exception):
    pass


@dataclasses.dataclass
class FormCreateParams:
    workflow_execution_id: str | None
    node_id: str
    form_config: HumanInputNodeData
    rendered_content: str
    delivery_methods: Sequence[DeliveryChannelConfig]
    display_in_ui: bool
    resolved_default_values: Mapping[str, Any]
    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME


class HumanInputFormRecipientEntity(Protocol):
    @property
    def id(self) -> str: ...

    @property
    def token(self) -> str: ...


class HumanInputFormEntity(Protocol):
    @property
    def id(self) -> str: ...

    @property
    def submission_token(self) -> str | None: ...

    @property
    def recipients(self) -> list[HumanInputFormRecipientEntity]: ...

    @property
    def rendered_content(self) -> str: ...

    @property
    def selected_action_id(self) -> str | None: ...

    @property
    def submitted_data(self) -> Mapping[str, Any] | None: ...

    @property
    def submitted(self) -> bool: ...

    @property
    def status(self) -> HumanInputFormStatus: ...

    @property
    def expiration_time(self) -> datetime: ...


class HumanInputFormRepository(Protocol):
    def get_form(self, node_id: str) -> HumanInputFormEntity | None: ...

    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: ...


class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
    def __init__(self, recipient_model: HumanInputFormRecipient):
        self._recipient_model = recipient_model

@@ -77,7 +127,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
    self._form_model = form_model
    self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
    self._web_app_recipient = next(
    self._interactive_surface_recipient = next(
        (
            recipient
            for recipient in recipient_models

@@ -98,12 +148,12 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
    return self._form_model.id

@property
def web_app_token(self):
def submission_token(self) -> str | None:
    if self._console_recipient is not None:
        return self._console_recipient.access_token
    if self._web_app_recipient is None:
    if self._interactive_surface_recipient is None:
        return None
    return self._web_app_recipient.access_token
    return self._interactive_surface_recipient.access_token

@property
def recipients(self) -> list[HumanInputFormRecipientEntity]:

@@ -201,8 +251,16 @@ class HumanInputFormRepositoryImpl:
self,
*,
tenant_id: str,
):
app_id: str | None = None,
workflow_execution_id: str | None = None,
invoke_source: str | None = None,
submission_actor_id: str | None = None,
) -> None:
self._tenant_id = tenant_id
self._app_id = app_id
self._workflow_execution_id = workflow_execution_id
self._invoke_source = invoke_source
self._submission_actor_id = submission_actor_id

def _delivery_method_to_model(
    self,

@@ -219,7 +277,7 @@ class HumanInputFormRepositoryImpl:
    channel_payload=delivery_method.model_dump_json(),
)
recipients: list[HumanInputFormRecipient] = []
if isinstance(delivery_method, WebAppDeliveryMethod):
if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod):
    recipient_model = HumanInputFormRecipient(
        form_id=form_id,
        delivery_id=delivery_id,

@@ -247,16 +305,16 @@ class HumanInputFormRepositoryImpl:
delivery_id: str,
recipients_config: EmailRecipients,
) -> list[HumanInputFormRecipient]:
member_user_ids = [
    recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient)
bound_reference_ids = [
    recipient.reference_id for recipient in recipients_config.items if isinstance(recipient, BoundRecipient)
]
external_emails = [
    recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient)
]
if recipients_config.whole_workspace:
if recipients_config.include_bound_group:
    members = self._query_all_workspace_members(session=session)
else:
    members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids)
    members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=bound_reference_ids)

return self._create_email_recipients_from_resolved(
    form_id=form_id,

@@ -338,8 +396,33 @@ class HumanInputFormRepositoryImpl:
    rows = session.execute(stmt).all()
    return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]

def _should_create_console_recipient(
    self,
    *,
    form_config: HumanInputNodeData,
    form_kind: HumanInputFormKind,
) -> bool:
    if form_kind != HumanInputFormKind.RUNTIME:
        return False
    if self._invoke_source == "debugger":
        return True
    if self._invoke_source == "explore":
        return is_human_input_webapp_enabled(form_config)
    return False

def _should_create_backstage_recipient(self, *, form_kind: HumanInputFormKind) -> bool:
    return form_kind == HumanInputFormKind.RUNTIME and (
        self._invoke_source is not None or self._submission_actor_id is not None
    )
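The two predicates above gate the implicit console and backstage recipients. A standalone restatement of the console rule, useful for reasoning about it in isolation ("debugger" and "explore" are the invoke_source values handled in the diff; the free function and its parameter names are illustrative):

def should_create_console_recipient(
    invoke_source: str | None,
    is_runtime_form: bool,
    webapp_enabled: bool,
) -> bool:
    # Mirrors _should_create_console_recipient: only runtime forms qualify;
    # debugger invocations always get a console recipient, explore
    # invocations only when the web-app surface is enabled.
    if not is_runtime_form:
        return False
    if invoke_source == "debugger":
        return True
    if invoke_source == "explore":
        return webapp_enabled
    return False


assert should_create_console_recipient("debugger", True, False) is True
assert should_create_console_recipient("explore", True, False) is False
assert should_create_console_recipient("explore", True, True) is True
assert should_create_console_recipient("debugger", False, True) is False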
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
    form_config: HumanInputNodeData = params.form_config
    app_id = self._app_id
    if not app_id:
        raise ValueError("app_id is required to create a human input form")
    workflow_execution_id = params.workflow_execution_id or self._workflow_execution_id
    if params.form_kind == HumanInputFormKind.RUNTIME and workflow_execution_id is None:
        raise ValueError("workflow_execution_id is required for runtime human input forms")

    with session_factory.create_session() as session, session.begin():
        # Generate unique form ID

@@ -359,8 +442,8 @@ class HumanInputFormRepositoryImpl:
form_model = HumanInputForm(
    id=form_id,
    tenant_id=self._tenant_id,
    app_id=params.app_id,
    workflow_run_id=params.workflow_execution_id,
    app_id=app_id,
    workflow_run_id=workflow_execution_id,
    form_kind=params.form_kind,
    node_id=params.node_id,
    form_definition=form_definition.model_dump_json(),

@@ -379,7 +462,7 @@ class HumanInputFormRepositoryImpl:
session.add(delivery_and_recipients.delivery)
session.add_all(delivery_and_recipients.recipients)
recipient_models.extend(delivery_and_recipients.recipients)
if params.console_recipient_required and not any(
if self._should_create_console_recipient(form_config=form_config, form_kind=params.form_kind) and not any(
    recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models
):
    console_delivery_id = str(uuidv7())

@@ -395,13 +478,13 @@ class HumanInputFormRepositoryImpl:
delivery_id=console_delivery_id,
recipient_type=RecipientType.CONSOLE,
recipient_payload=ConsoleRecipientPayload(
    account_id=params.console_creator_account_id,
    account_id=self._submission_actor_id,
).model_dump_json(),
)
session.add(console_delivery)
session.add(console_recipient)
recipient_models.append(console_recipient)
if params.backstage_recipient_required and not any(
if self._should_create_backstage_recipient(form_kind=params.form_kind) and not any(
    recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models
):
    backstage_delivery_id = str(uuidv7())

@@ -417,7 +500,7 @@ class HumanInputFormRepositoryImpl:
delivery_id=backstage_delivery_id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload(
    account_id=params.console_creator_account_id,
    account_id=self._submission_actor_id,
).model_dump_json(),
)
session.add(backstage_delivery)

@@ -427,9 +510,12 @@ class HumanInputFormRepositoryImpl:

    return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)

def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
def get_form(self, node_id: str) -> HumanInputFormEntity | None:
    if self._workflow_execution_id is None:
        raise ValueError("workflow_execution_id is required to load runtime human input forms")

    form_query = select(HumanInputForm).where(
        HumanInputForm.workflow_run_id == workflow_execution_id,
        HumanInputForm.workflow_run_id == self._workflow_execution_id,
        HumanInputForm.node_id == node_id,
        HumanInputForm.tenant_id == self._tenant_id,
    )
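With the constructor now carrying the execution context, call sites no longer pass workflow_execution_id per call. A hedged sketch of the resulting usage, based on the constructor and properties shown in the diff (all identifier values are placeholders):

repo = HumanInputFormRepositoryImpl(
    tenant_id="tenant-1",
    app_id="app-1",
    workflow_execution_id="run-1",
    invoke_source="debugger",
    submission_actor_id="account-1",
)
form = repo.get_form(node_id="human_input_node")
if form is not None and not form.submitted:
    # The entity protocol exposes rendered content and expiry directly.
    print(form.rendered_content, form.expiration_time)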
@@ -9,9 +9,9 @@ from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.repositories.factory import WorkflowExecutionRepository
from dify_graph.entities import WorkflowExecution
from dify_graph.enums import WorkflowExecutionStatus, WorkflowType
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.helper import extract_tenant_id
from models import (
Some files were not shown because too many files have changed in this diff.