refactor(api): continue decoupling dify_graph from API concerns (#33580)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
Author: -LAN-
Date: 2026-03-25 20:32:24 +08:00 (committed by GitHub)
Parent: b7b9b003c9
Commit: 56593f20b0

487 changed files with 17,999 additions and 9,186 deletions

View File

@@ -1,10 +1,14 @@
[importlinter]
root_packages =
core
constants
context
dify_graph
configs
controllers
extensions
factories
libs
models
tasks
services
@@ -33,29 +37,19 @@ ignore_imports =
# TODO(QuantumGhost): fix the import violation later
dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
dify_graph
forbidden_modules =
extensions.ext_database
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
type = forbidden
source_modules =
dify_graph
forbidden_modules =
constants
configs
context
controllers
extensions
factories
libs
models
services
tasks
@@ -88,46 +82,14 @@ forbidden_modules =
core.tools
core.trigger
core.variables
ignore_imports =
dify_graph.nodes.llm.llm_utils -> core.model_manager
dify_graph.nodes.llm.protocols -> core.model_manager
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
dify_graph.nodes.llm.node -> core.tools.signature
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.llm.node -> models.dataset
dify_graph.nodes.llm.file_saver -> core.tools.signature
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.tool.tool_node -> services
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.__base.large_language_model -> configs
dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
dify_graph.model_runtime.model_providers.model_provider_factory -> configs
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids
[importlinter:contract:workflow-third-party-imports]
name = Workflow Third-Party Imports
type = forbidden
source_modules =
dify_graph
forbidden_modules =
sqlalchemy
[importlinter:contract:rsc]
name = RSC

View File

@@ -1,74 +1,36 @@
"""
Core Context - Framework-agnostic context management.
Application-layer context adapters.
This module provides context management that is independent of any specific
web framework. Framework-specific implementations register their context
capture functions at application initialization time.
This ensures the workflow layer remains completely decoupled from Flask
or any other web framework.
Concrete execution-context implementations live here so `dify_graph` only
depends on injected context managers rather than framework state capture.
"""
import contextvars
from collections.abc import Callable
from dify_graph.context.execution_context import (
from context.execution_context import (
AppContext,
ContextProviderNotFoundError,
ExecutionContext,
ExecutionContextBuilder,
IExecutionContext,
NullAppContext,
capture_current_context,
read_context,
register_context,
register_context_capturer,
reset_context_provider,
)
# Global capturer function - set by framework-specific modules
_capturer: Callable[[], IExecutionContext] | None = None
def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
"""
Register a context capture function.
This should be called by framework-specific modules (e.g., Flask)
during application initialization.
Args:
capturer: Function that captures current context and returns IExecutionContext
"""
global _capturer
_capturer = capturer
def capture_current_context() -> IExecutionContext:
"""
Capture current execution context.
This function uses the registered context capturer. If no capturer
is registered, it returns a minimal context with only contextvars
(suitable for non-framework environments like tests or standalone scripts).
Returns:
IExecutionContext with captured context
"""
if _capturer is None:
# No framework registered - return minimal context
return ExecutionContext(
app_context=NullAppContext(),
context_vars=contextvars.copy_context(),
)
return _capturer()
def reset_context_provider() -> None:
"""
Reset the context capturer.
This is primarily useful for testing to ensure a clean state.
"""
global _capturer
_capturer = None
from context.models import SandboxContext
__all__ = [
"AppContext",
"ContextProviderNotFoundError",
"ExecutionContext",
"ExecutionContextBuilder",
"IExecutionContext",
"NullAppContext",
"SandboxContext",
"capture_current_context",
"read_context",
"register_context",
"register_context_capturer",
"reset_context_provider",
]
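The hunk above collapses the module's own capturer bookkeeping into re-exports from context.execution_context. As a quick orientation, here is a minimal sketch of how the re-exported API fits together, assuming the context package exposes the names listed in __all__; the worker wiring is illustrative, not code from this commit.

import contextvars
import threading

from context import (
    ExecutionContext,
    NullAppContext,
    capture_current_context,
    register_context_capturer,
)

# A framework adapter registers its capturer once at startup. Registering
# the same minimal context the fallback path builds keeps the sketch honest.
register_context_capturer(
    lambda: ExecutionContext(
        app_context=NullAppContext(),
        context_vars=contextvars.copy_context(),
    )
)

def spawn_worker() -> None:
    # Capture on the calling thread; re-enter inside the worker thread.
    ctx = capture_current_context()

    def worker() -> None:
        with ctx:  # IExecutionContext is an enterable protocol
            ...  # graph execution runs with the restored contextvars

    threading.Thread(target=worker).start()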

View File

@@ -1,5 +1,8 @@
"""
Execution Context - Abstracted context management for workflow execution.
Application-layer execution context adapters.
Concrete context capture lives outside `dify_graph` so the graph package only
consumes injected context managers when it needs to preserve thread-local state.
"""
import contextvars
@@ -16,33 +19,33 @@ class AppContext(ABC):
"""
Abstract application context interface.
This abstraction allows workflow execution to work with or without Flask
by providing a common interface for application context management.
Application adapters can implement this to restore framework-specific state
such as Flask app context around worker execution.
"""
@abstractmethod
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
pass
raise NotImplementedError
@abstractmethod
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name (e.g., 'db', 'cache')."""
pass
"""Get application extension by name."""
raise NotImplementedError
@abstractmethod
def enter(self) -> AbstractContextManager[None]:
"""Enter the application context."""
pass
raise NotImplementedError
@runtime_checkable
class IExecutionContext(Protocol):
"""
Protocol for execution context.
Protocol for enterable execution context objects.
This protocol defines the interface that all execution contexts must implement,
allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
Concrete implementations may carry extra framework state, but callers only
depend on standard context-manager behavior plus optional user metadata.
"""
def __enter__(self) -> "IExecutionContext":
@@ -62,14 +65,10 @@ class IExecutionContext(Protocol):
@final
class ExecutionContext:
"""
Execution context for workflow execution in worker threads.
Generic execution context used by application-layer adapters.
This class encapsulates all context needed for workflow execution:
- Application context (Flask app or standalone)
- Context variables for Python contextvars
- User information (optional)
It is designed to be serializable and passable to worker threads.
It restores captured `contextvars` and optionally enters an application
context before the worker executes graph logic.
"""
def __init__(
@@ -78,14 +77,6 @@ class ExecutionContext:
context_vars: contextvars.Context | None = None,
user: Any = None,
) -> None:
"""
Initialize execution context.
Args:
app_context: Application context (Flask or standalone)
context_vars: Python contextvars to preserve
user: User object (optional)
"""
self._app_context = app_context
self._context_vars = context_vars
self._user = user
@@ -98,27 +89,21 @@ class ExecutionContext:
@property
def context_vars(self) -> contextvars.Context | None:
"""Get context variables."""
"""Get captured context variables."""
return self._context_vars
@property
def user(self) -> Any:
"""Get user object."""
"""Get captured user object."""
return self._user
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""
Enter this execution context.
This is a convenience method that creates a context manager.
"""
# Restore context variables if provided
"""Enter this execution context."""
if self._context_vars:
for var, val in self._context_vars.items():
var.set(val)
# Enter app context if available
if self._app_context is not None:
with self._app_context.enter():
yield
@@ -141,18 +126,10 @@ class ExecutionContext:
class NullAppContext(AppContext):
"""
Null implementation of AppContext for non-Flask environments.
This is used when running without Flask (e.g., in tests or standalone mode).
Null application context for non-framework environments.
"""
def __init__(self, config: dict[str, Any] | None = None) -> None:
"""
Initialize null app context.
Args:
config: Optional configuration dictionary
"""
self._config = config or {}
self._extensions: dict[str, Any] = {}
@@ -165,7 +142,7 @@ class NullAppContext(AppContext):
return self._extensions.get(name)
def set_extension(self, name: str, extension: Any) -> None:
"""Set extension by name."""
"""Register an extension for tests or standalone execution."""
self._extensions[name] = extension
@contextmanager
@@ -176,9 +153,7 @@ class NullAppContext(AppContext):
class ExecutionContextBuilder:
"""
Builder for creating ExecutionContext instances.
This provides a fluent API for building execution contexts.
Builder for creating `ExecutionContext` instances.
"""
def __init__(self) -> None:
@@ -211,63 +186,42 @@ class ExecutionContextBuilder:
_capturer: Callable[[], IExecutionContext] | None = None
# Tenant-scoped providers using tuple keys for clarity and constant-time lookup.
# Key mapping:
# (name, tenant_id) -> provider
# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox")
# - tenant_id: tenant identifier string
# Value:
# provider: Callable[[], BaseModel] returning the typed context value
# Type-safety note:
# - This registry cannot enforce that all providers for a given name return the same BaseModel type.
# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice),
# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and
# def read_sandbox_ctx(tenant_id: str) -> SandboxContext.
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
T = TypeVar("T", bound=BaseModel)
class ContextProviderNotFoundError(KeyError):
"""Raised when a tenant-scoped context provider is missing for a given (name, tenant_id)."""
"""Raised when a tenant-scoped context provider is missing."""
pass
def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
"""Register a single enterable execution context capturer (e.g., Flask)."""
"""Register an enterable execution context capturer."""
global _capturer
_capturer = capturer
def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None:
"""Register a tenant-specific provider for a named context.
Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions.
Consider adding a typed wrapper for this registration in your feature module.
"""
"""Register a tenant-specific provider for a named context."""
_tenant_context_providers[(name, tenant_id)] = provider
def read_context(name: str, *, tenant_id: str) -> BaseModel:
"""
Read a context value for a specific tenant.
Raises KeyError if the provider for (name, tenant_id) is not registered.
"""
prov = _tenant_context_providers.get((name, tenant_id))
if prov is None:
"""Read a context value for a specific tenant."""
provider = _tenant_context_providers.get((name, tenant_id))
if provider is None:
raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'")
return prov()
return provider()
def capture_current_context() -> IExecutionContext:
"""
Capture current execution context from the calling environment.
If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal
context with NullAppContext + copy of current contextvars.
If no framework adapter is registered, return a minimal context that only
restores `contextvars`.
"""
if _capturer is None:
return ExecutionContext(
@@ -278,7 +232,22 @@ def capture_current_context() -> IExecutionContext:
def reset_context_provider() -> None:
"""Reset the capturer and all tenant-scoped context providers (primarily for tests)."""
"""Reset the capturer and tenant-scoped providers."""
global _capturer
_capturer = None
_tenant_context_providers.clear()
__all__ = [
"AppContext",
"ContextProviderNotFoundError",
"ExecutionContext",
"ExecutionContextBuilder",
"IExecutionContext",
"NullAppContext",
"capture_current_context",
"read_context",
"register_context",
"register_context_capturer",
"reset_context_provider",
]
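The removed registry comment recommends typed wrappers around register_context/read_context so each feature pins the concrete model behind a namespaced key. A minimal sketch of such a wrapper, assuming SandboxContext is the pydantic model re-exported by the context package; the "workflow.sandbox" key follows the naming the comment suggests.

from collections.abc import Callable

from context import SandboxContext, read_context, register_context

_SANDBOX_KEY = "workflow.sandbox"  # namespaced to avoid key collisions

def register_sandbox_ctx(tenant_id: str, provider: Callable[[], SandboxContext]) -> None:
    """Register a tenant-scoped provider for the sandbox context."""
    register_context(_SANDBOX_KEY, tenant_id, provider)

def read_sandbox_ctx(tenant_id: str) -> SandboxContext:
    """Read and narrow the BaseModel the registry hands back."""
    value = read_context(_SANDBOX_KEY, tenant_id=tenant_id)
    if not isinstance(value, SandboxContext):  # the registry cannot enforce this
        raise TypeError(f"expected SandboxContext, got {type(value)}")
    return value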

View File

@@ -10,11 +10,7 @@ from typing import Any, final
from flask import Flask, current_app, g
from dify_graph.context import register_context_capturer
from dify_graph.context.execution_context import (
AppContext,
IExecutionContext,
)
from context.execution_context import AppContext, IExecutionContext, register_context_capturer
@final

View File

@@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING:
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
@@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers")
)
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)

View File

@@ -88,6 +88,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_tenant_id,
@@ -127,6 +128,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
)
except Exception:
continue

View File

@@ -20,6 +20,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.file_access import DatabaseFileAccessController
from core.helper.trace_id_helper import get_external_trace_id
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
@@ -51,6 +52,7 @@ from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
@@ -204,6 +206,7 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
access_controller=_file_access_controller,
)
return file_objs
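This hunk threads an explicit access controller into the file factory instead of letting the factory consult the database on its own. A minimal sketch of the call-site pattern, with parse_files as an illustrative name:

from collections.abc import Mapping, Sequence

from core.app.file_access import DatabaseFileAccessController
from factories.file_factory import build_from_mappings

# One module-level controller per handler module, as in this commit.
_file_access_controller = DatabaseFileAccessController()

def parse_files(mappings: Sequence[Mapping], tenant_id: str, config):
    # The caller injects the controller that authorizes each file lookup.
    return build_from_mappings(
        mappings=mappings,
        tenant_id=tenant_id,
        config=config,
        access_controller=_file_access_controller,
    )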

View File

@@ -15,7 +15,8 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.file import helpers as file_helpers
from dify_graph.variables.segment_group import SegmentGroup
from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment
@@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -389,13 +391,21 @@ class VariableApi(Resource):
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@@ -12,6 +12,7 @@ from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError
from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from extensions.ext_database import db
@@ -496,6 +497,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
form_tokens_by_form_id = _load_form_tokens_by_form_id(
[reason.form_id for reason in pause_reasons if isinstance(reason, HumanInputRequired)]
)
# Build response
paused_at = pause_entity.paused_at if pause_entity else None
@@ -514,7 +518,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
"pause_type": {
"type": "human_input",
"form_id": reason.form_id,
"backstage_input_url": _build_backstage_input_url(reason.form_token),
"backstage_input_url": _build_backstage_input_url(
form_tokens_by_form_id.get(reason.form_id)
),
},
}
)
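Instead of carrying a token on every pause reason, the handler now batch-loads tokens once and looks them up per form. A sketch of the caller-side pattern, assuming load_form_tokens_by_form_id returns a form_id -> token mapping (which the .get(reason.form_id) call above implies):

from core.workflow.human_input_forms import load_form_tokens_by_form_id
from dify_graph.entities.pause_reason import HumanInputRequired

def human_input_payloads(pause_reasons):
    # One batched lookup instead of one token fetch per human-input reason.
    tokens = load_form_tokens_by_form_id(
        [r.form_id for r in pause_reasons if isinstance(r, HumanInputRequired)]
    )
    return [
        {"type": "human_input", "form_id": r.form_id, "form_token": tokens.get(r.form_id)}
        for r in pause_reasons
        if isinstance(r, HumanInputRequired)
    ]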

View File

@@ -25,7 +25,7 @@ from controllers.console.wraps import (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
@@ -332,7 +332,7 @@ class DatasetListApi(Resource):
)
# check embedding setting
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -446,7 +446,7 @@ class DatasetApi(Resource):
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
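The ProviderManager() call sites in this file all move to a tenant-scoped factory. A minimal sketch of the resulting pattern, grounded in the calls shown above:

from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from dify_graph.model_runtime.entities.model_entities import ModelType

def active_embedding_models(tenant_id: str):
    # The manager is built per tenant rather than instantiated globally, so
    # plugin-backed providers resolve against the caller's tenant scope.
    provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
    configurations = provider_manager.get_configurations(tenant_id=tenant_id)
    return configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)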

View File

@@ -454,7 +454,7 @@ class DatasetInitApi(Resource):
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=knowledge_config.embedding_model_provider,
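ModelManager construction follows the same tenant-scoping move. A sketch of the pattern; the trailing get_model_instance arguments are assumed from typical call sites, since the hunk truncates them:

from core.model_manager import ModelManager
from dify_graph.model_runtime.entities.model_entities import ModelType

def embedding_instance(tenant_id: str, provider: str, model: str):
    # for_tenant() replaces the bare ModelManager() constructor everywhere.
    model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
    return model_manager.get_model_instance(
        tenant_id=tenant_id,
        provider=provider,
        model_type=ModelType.TEXT_EMBEDDING,  # assumed for this example
        model=model,  # assumed parameter name
    )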

View File

@@ -283,7 +283,7 @@ class DatasetDocumentSegmentApi(Resource):
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -336,7 +336,7 @@ class DatasetDocumentSegmentAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -387,7 +387,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -572,7 +572,7 @@ class ChildChunkAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,

View File

@@ -21,7 +21,8 @@ from controllers.console.app.workflow_draft_variable import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
@@ -33,6 +34,7 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
def _create_pagination_parser():
@@ -223,13 +225,21 @@ class RagPipelineVariableApi(Resource):
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource):
)
if args.config_from == "predefined-model":
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
tenant_id=tenant_id, provider_name=provider
available_credentials = model_provider_service.get_provider_available_credentials(
tenant_id=tenant_id,
provider=provider,
)
else:
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
available_credentials = model_provider_service.get_provider_model_available_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=normalized_model_type,
model=args.model,
)
return jsonable_encoder(

View File

@@ -70,22 +70,25 @@ class ToolFileApi(Resource):
except Exception:
raise UnsupportedFileTypeError()
mime_type = tool_file.mime_type
filename = tool_file.filename
response = Response(
stream,
mimetype=tool_file.mimetype,
mimetype=mime_type,
direct_passthrough=True,
headers={},
)
if tool_file.size > 0:
response.headers["Content-Length"] = str(tool_file.size)
if args.as_attachment:
encoded_filename = quote(tool_file.name)
if args.as_attachment and filename:
encoded_filename = quote(filename)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
enforce_download_for_html(
response,
mime_type=tool_file.mimetype,
filename=tool_file.name,
mime_type=mime_type,
filename=filename,
extension=extension,
)
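The download handler now reads mime_type/filename once and only emits headers it can back with data. A self-contained sketch of the header logic this hunk converges on:

from urllib.parse import quote

from flask import Response

def attach_download_headers(response: Response, filename: str | None, size: int) -> Response:
    # Content-Length only when the stored size is known and positive.
    if size > 0:
        response.headers["Content-Length"] = str(size)
    # RFC 5987 encoding keeps non-ASCII filenames intact across browsers.
    if filename:
        encoded = quote(filename)
        response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded}"
    return response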

View File

@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
import services
from core.tools.signature import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file.helpers import verify_plugin_file_signature
from fields.file_fields import FileResponse
from ..common.errors import (

View File

@@ -28,7 +28,7 @@ from core.plugin.entities.request import (
RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.file.helpers import get_signed_file_url_for_plugin
from core.tools.signature import get_signed_file_url_for_plugin
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import length_prefixed_response
from models import Account, Tenant

View File

@@ -14,7 +14,7 @@ from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
)
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from fields.dataset_fields import dataset_detail_fields
@@ -140,10 +140,10 @@ class DatasetListApi(DatasetApiResource):
query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
)
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -259,10 +259,10 @@ class DatasetApi(DatasetApiResource):
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

View File

@@ -106,7 +106,7 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -160,7 +160,7 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -266,7 +266,7 @@ class DatasetSegmentApi(DatasetApiResource):
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -361,7 +361,7 @@ class ChildChunkApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,

View File

@@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ModelConfigWithCredentialsEntity,
)
from core.app.file_access import DatabaseFileAccessController
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -46,6 +47,7 @@ from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class BaseAgentRunner(AppRunner):
@@ -138,6 +140,7 @@ class BaseAgentRunner(AppRunner):
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
user_id=self.user_id,
invoke_from=self.application_generate_entity.invoke_from,
)
assert tool_entity.entity.description
@@ -524,7 +527,10 @@ class BaseAgentRunner(AppRunner):
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
message_files=files,
tenant_id=self.tenant_id,
config=file_extra_config,
access_controller=_file_access_controller,
)
if not file_objs:
return UserPromptMessage(content=message.query)

View File

@@ -122,7 +122,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tools=[],
stop=app_generate_entity.model_conf.stop,
stream=True,
user=self.user_id,
callbacks=[],
)

View File

@@ -96,7 +96,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tools=prompt_messages_tools,
stop=app_generate_entity.model_conf.stop,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)

View File

@@ -4,7 +4,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -21,7 +21,7 @@ class ModelConfigConverter:
"""
model_config = app_config.model
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
)

View File

@@ -2,9 +2,8 @@ from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.model import AppModelConfigDict
from models.provider_ids import ModelProviderID
@@ -54,9 +53,12 @@ class ModelConfigManager:
if not isinstance(config["model"], dict):
raise ValueError("model must be of object type")
# Keep provider discovery and provider-backed model listing on the same
# request-scoped runtime so caller scope and provider caches stay aligned.
assembly = create_plugin_model_assembly(tenant_id=tenant_id)
# model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
provider_entities = model_provider_factory.get_providers()
provider_entities = assembly.model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"]:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
@@ -71,8 +73,7 @@ class ModelConfigManager:
if "name" not in config["model"]:
raise ValueError("model.name is required")
provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
models = assembly.provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"], model_type=ModelType.LLM
)
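Per the comment in the hunk, the assembly keeps provider discovery and provider-backed model listing on one request-scoped runtime. A sketch of the access pattern, using only the attributes the hunk itself exercises:

from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from dify_graph.model_runtime.entities.model_entities import ModelType

def validate_llm_config(tenant_id: str, provider: str):
    # One assembly per request keeps factory and manager caches aligned.
    assembly = create_plugin_model_assembly(tenant_id=tenant_id)
    provider_names = [p.provider for p in assembly.model_provider_factory.get_providers()]
    models = assembly.provider_manager.get_configurations(tenant_id).get_models(
        provider=provider, model_type=ModelType.LLM
    )
    return provider in provider_names, models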

View File

@@ -24,6 +24,7 @@ from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
@@ -34,13 +35,9 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from dify_graph.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.runtime import GraphRuntimeState
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
@@ -150,85 +147,87 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
access_controller=self._file_access_controller,
)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return self._generate(
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
return self._generate(
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
def resume(
self,
@@ -460,94 +459,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = conversation is None
with self._bind_file_access_scope(
tenant_id=application_generate_entity.app_config.tenant_id,
user=user,
invoke_from=invoke_from,
):
is_first_conversation = conversation is None
if conversation is not None and message is not None:
pass
else:
conversation, message = self._init_generate_records(application_generate_entity, conversation)
if conversation is not None and message is not None:
pass
else:
conversation, message = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
if is_first_conversation:
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
# get conversation dialogue count
# NOTE: dialogue_count should not start from 0,
# because during the first conversation, dialogue_count should be 1.
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
# get conversation dialogue count
# NOTE: dialogue_count should not start from 0,
# because during the first conversation, dialogue_count should be 1.
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context and contextvars
context = contextvars.copy_context()
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
# new thread with request context and contextvars
context = contextvars.copy_context()
worker_thread.start()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
# message_ = session.get(Message, message.id)
# assert message_ is not None
# message = message_
# db.session.refresh(workflow)
# db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
worker_thread.start()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
db.session.close()
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,

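The generator now wraps file building and the whole generate path in self._bind_file_access_scope(...). Its body is not part of this hunk, so the following is a hypothetical sketch of the shape such a scope binder could take; push_scope/pop_scope are invented names for illustration only:

from contextlib import contextmanager

@contextmanager
def _bind_file_access_scope(self, *, tenant_id, user, invoke_from):
    # Hypothetical shape: record the caller's identity so the file access
    # controller can authorize build_from_mappings() calls made inside.
    token = self._file_access_controller.push_scope(  # assumed helper
        tenant_id=tenant_id, user=user, invoke_from=invoke_from
    )
    try:
        yield
    finally:
        self._file_access_controller.pop_scope(token)  # assumed helper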
View File

@@ -25,14 +25,19 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import get_default_root_node_id
from core.workflow.system_variables import (
build_bootstrap_variables,
build_system_variables,
system_variables_to_mapping,
)
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.enums import WorkflowType
from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import VariableLoader
from dify_graph.variables.variables import Variable
from extensions.ext_database import db
@@ -90,7 +95,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
system_inputs = SystemVariable(
system_inputs = build_system_variables(
query=self.application_generate_entity.query,
files=self.application_generate_entity.files,
conversation_id=self.conversation.id,
@@ -150,7 +155,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.application_generate_entity.inputs = new_inputs
self.application_generate_entity.query = new_query
system_inputs.query = new_query
system_inputs = build_system_variables(
system_variables_to_mapping(system_inputs),
query=new_query,
)
# annotation reply
if self.handle_annotation_reply(
@@ -166,14 +174,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Create a variable pool.
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=new_inputs,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
variable_pool = VariablePool()
add_variables_to_pool(
variable_pool,
build_bootstrap_variables(
system_variables=system_inputs,
environment_variables=self._workflow.environment_variables,
conversation_variables=conversation_variables,
),
)
root_node_id = get_default_root_node_id(self._workflow.graph_dict)
add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=new_inputs)
# init graph
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
@@ -185,6 +196,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
root_node_id=root_node_id,
)
db.session.close()
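The runner replaces direct VariablePool(...) construction with explicit bootstrap helpers. A condensed sketch of the new initialization flow, using only names that appear in this hunk (the build_system_variables keywords are abbreviated):

from core.workflow.node_factory import get_default_root_node_id
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from dify_graph.runtime import VariablePool

def init_variable_pool(workflow, query, files, conversation_id, user_inputs, conversation_variables):
    system_inputs = build_system_variables(query=query, files=files, conversation_id=conversation_id)
    pool = VariablePool()
    add_variables_to_pool(
        pool,
        build_bootstrap_variables(
            system_variables=system_inputs,
            environment_variables=workflow.environment_variables,
            conversation_variables=conversation_variables,
        ),
    )
    # User inputs now attach to the resolved root node rather than the pool itself.
    root_node_id = get_default_root_node_id(workflow.graph_dict)
    add_node_inputs_to_pool(pool, node_id=root_node_id, inputs=user_inputs)
    return pool, root_node_id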

View File

@@ -14,6 +14,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
@@ -65,14 +66,14 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import build_system_variables
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.runtime import GraphRuntimeState
from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
@@ -117,7 +118,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_system_variables = SystemVariable(
self._workflow_system_variables = build_system_variables(
query=message.query,
files=application_generate_entity.files,
conversation_id=conversation.id,
@@ -741,8 +742,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
form_repository = HumanInputFormRepositoryImpl(
tenant_id=self._workflow_tenant_id,
workflow_execution_id=self._workflow_run_id,
)
form = form_repository.get_form(self._workflow_run_id, node_id)
form = form_repository.get_form(node_id)
if form is None:
return None
return form.id
@@ -933,21 +935,23 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))
message_files = [
    MessageFile(
        message_id=message.id,
        type=file["type"],
        transfer_method=file["transfer_method"],
        url=file["remote_url"],
        belongs_to=MessageFileBelongsTo.ASSISTANT,
        upload_file_id=file["related_id"],
        created_by_role=CreatorUserRole.ACCOUNT
        if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
        else CreatorUserRole.END_USER,
        created_by=message.from_account_id or message.from_end_user_id or "",
    )
    for file in self._recorded_files
]
message_files: list[MessageFile] = []
for file in self._recorded_files:
    reference = file.get("reference") or file.get("related_id")
    message_files.append(
        MessageFile(
            message_id=message.id,
            type=file["type"],
            transfer_method=file["transfer_method"],
            url=file["remote_url"],
            belongs_to=MessageFileBelongsTo.ASSISTANT,
            upload_file_id=resolve_file_record_id(reference if isinstance(reference, str) else None),
            created_by_role=CreatorUserRole.ACCOUNT
            if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
            else CreatorUserRole.END_USER,
            created_by=message.from_account_id or message.from_end_user_id or "",
        )
    )
session.add_all(message_files)
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
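`resolve_file_record_id` is new here and its exact behavior is not shown in this diff. A hypothetical helper in the same spirit — accept either a bare upload-record id or a prefixed reference and normalize it — might look like:

    def resolve_record_id(reference: str | None) -> str | None:
        # hypothetical: the real resolve_file_record_id may parse references differently
        if not reference:
            return None
        prefix, _, rest = reference.partition(":")
        if rest:  # looks like "<kind>:<id>"
            return rest if prefix == "upload_file" else None
        return reference  # already a bare record id

    assert resolve_record_id("abc123") == "abc123"
    assert resolve_record_id("upload_file:abc123") == "abc123"
    assert resolve_record_id(None) is None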
@@ -1003,13 +1007,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
return message
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
with Session(db.engine) as session, session.begin():
saver = self._draft_var_saver_factory(
session=session,
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)
saver = self._draft_var_saver_factory(
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)
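Note the boundary shift in `_save_output_for_event`: the pipeline no longer opens a `Session` and hands it to the factory; the returned saver owns persistence. A self-contained sketch of that shape (the `fake_session` stand-in is illustrative, not the real `Session(db.engine)` usage):

    from collections.abc import Mapping
    from contextlib import contextmanager
    from typing import Any

    @contextmanager
    def fake_session():
        print("begin")
        try:
            yield "session"
            print("commit")
        except Exception:
            print("rollback")
            raise

    class SessionOwningSaver:
        def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
            with fake_session():  # the adapter, not the caller, controls the transaction
                print("persisting", dict(outputs or {}))

    SessionOwningSaver().save(None, {"result": "ok"})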

View File

@@ -129,89 +129,93 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
    files = args.get("files") or []
    file_extra_config = FileUploadConfigManager.convert(
        override_model_config_dict or app_model_config.to_dict()
    )
    if file_extra_config:
        file_objs = file_factory.build_from_mappings(
            mappings=files,
            tenant_id=app_model.tenant_id,
            config=file_extra_config,
            access_controller=self._file_access_controller,
        )
    else:
        file_objs = []

    # convert to app config
    app_config = AgentChatAppConfigManager.get_app_config(
        app_model=app_model,
        app_model_config=app_model_config,
        conversation=conversation,
        override_config_dict=override_model_config_dict,
    )

    # get tracing instance
    trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)

    # init application generate entity
    application_generate_entity = AgentChatAppGenerateEntity(
        task_id=str(uuid.uuid4()),
        app_config=app_config,
        model_conf=ModelConfigConverter.convert(app_config),
        file_upload_config=file_extra_config,
        conversation_id=conversation.id if conversation else None,
        inputs=self._prepare_user_inputs(
            user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
        ),
        query=query,
        files=list(file_objs),
        parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
        user_id=user.id,
        stream=streaming,
        invoke_from=invoke_from,
        extras=extras,
        call_depth=0,
        trace_manager=trace_manager,
    )

    # init generate records
    (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

    # init queue manager
    queue_manager = MessageBasedAppQueueManager(
        task_id=application_generate_entity.task_id,
        user_id=application_generate_entity.user_id,
        invoke_from=application_generate_entity.invoke_from,
        conversation_id=conversation.id,
        app_mode=conversation.mode,
        message_id=message.id,
    )

    # new thread with request context and contextvars
    context = contextvars.copy_context()

    worker_thread = threading.Thread(
        target=self._generate_worker,
        kwargs={
            "flask_app": current_app._get_current_object(),  # type: ignore
            "context": context,
            "application_generate_entity": application_generate_entity,
            "queue_manager": queue_manager,
            "conversation_id": conversation.id,
            "message_id": message.id,
        },
    )
    worker_thread.start()

    # return response or stream generator
    response = self._handle_response(
        application_generate_entity=application_generate_entity,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message,
        user=user,
        stream=streaming,
    )
    return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,

View File

@@ -1,17 +1,20 @@
from collections.abc import Generator, Mapping, Sequence
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.enums import NodeType
from dify_graph.file import File, FileUploadConfig
from dify_graph.repositories.draft_variable_repository import (
from core.app.apps.draft_variable_saver import (
DraftVariableSaver,
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
from dify_graph.variables.input_entities import VariableEntityType
from extensions.ext_database import db
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
@@ -21,7 +24,66 @@ if TYPE_CHECKING:
from dify_graph.variables.input_entities import VariableEntity
@final
class _DebuggerDraftVariableSaver:
"""Adapter that binds SQLAlchemy session setup outside the saver port."""
def __init__(
self,
*,
account: Account,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> None:
self._account = account
self._app_id = app_id
self._node_id = node_id
self._node_type = node_type
self._node_execution_id = node_execution_id
self._enclosing_node_id = enclosing_node_id
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
with Session(db.engine) as session, session.begin():
DraftVariableSaverImpl(
session=session,
app_id=self._app_id,
node_id=self._node_id,
node_type=self._node_type,
node_execution_id=self._node_execution_id,
enclosing_node_id=self._enclosing_node_id,
user=self._account,
).save(process_data, outputs)
class BaseAppGenerator:
_file_access_controller: DatabaseFileAccessController = DatabaseFileAccessController()
@staticmethod
def _bind_file_access_scope(
*,
tenant_id: str,
user: Account | EndUser,
invoke_from: InvokeFrom,
) -> AbstractContextManager[None]:
"""Bind request-scoped file ownership markers for downstream file lookups."""
user_id = getattr(user, "id", None)
if not isinstance(user_id, str) or not user_id:
return nullcontext()
user_from = UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER
return bind_file_access_scope(
FileAccessScope(
tenant_id=tenant_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
)
)
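`bind_file_access_scope` itself is not shown in this diff; a plausible shape, sketched with `contextvars` (assumed implementation detail, names hypothetical), makes "who is asking" ambient so deep file lookups can enforce ownership without threading user identity through every signature:

    import contextvars
    from contextlib import contextmanager
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Scope:
        tenant_id: str
        user_id: str

    _current_scope: contextvars.ContextVar[Scope | None] = contextvars.ContextVar("scope", default=None)

    @contextmanager
    def bind_scope(scope: Scope):
        token = _current_scope.set(scope)
        try:
            yield
        finally:
            _current_scope.reset(token)

    def load_file(file_id: str) -> str:
        scope = _current_scope.get()
        if scope is None:
            raise PermissionError("no file access scope bound")
        return f"{scope.tenant_id}/{file_id}"  # a real check would consult the database

    with bind_scope(Scope(tenant_id="t1", user_id="u1")):
        assert load_file("f9") == "t1/f9"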
def _prepare_user_inputs(
self,
*,
@@ -50,6 +112,7 @@ class BaseAppGenerator:
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
strict_type_validation=strict_type_validation,
access_controller=self._file_access_controller,
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
@@ -64,6 +127,7 @@ class BaseAppGenerator:
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
access_controller=self._file_access_controller,
)
for k, v in user_inputs.items()
if isinstance(v, list)
@@ -226,32 +290,30 @@ class BaseAppGenerator:
assert isinstance(account, Account)
def draft_var_saver_factory(
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> DraftVariableSaver:
return DraftVariableSaverImpl(
session=session,
return _DebuggerDraftVariableSaver(
account=account,
app_id=app_id,
node_id=node_id,
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
user=account,
)
else:
def draft_var_saver_factory(
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> DraftVariableSaver:
_ = app_id, node_id, node_type, node_execution_id, enclosing_node_id
return NoopDraftVariableSaver()
return draft_var_saver_factory
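The two branches above amount to: debugger runs get a saver bound to the inspecting account, every other invoke source gets a no-op. A condensed sketch of that selection (toy classes, not the real savers):

    from collections.abc import Mapping
    from typing import Any

    class NoopSaver:
        def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
            return None

    class AccountSaver:
        def __init__(self, account_id: str, node_id: str) -> None:
            self._account_id = account_id
            self._node_id = node_id

        def save(self, process_data, outputs) -> None:
            print(f"account {self._account_id} saving draft vars for {self._node_id}")

    def make_factory(is_debugger: bool, account_id: str | None):
        if is_debugger and account_id:
            return lambda node_id: AccountSaver(account_id, node_id)
        return lambda node_id: NoopSaver()

    factory = make_factory(is_debugger=True, account_id="acc-1")
    factory("node-3").save(None, {"x": 1})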

View File

@@ -61,27 +61,30 @@ class AppQueueManager(ABC):
listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
start_time = time.time()
last_ping_time: int | float = 0
try:
    while True:
        try:
            message = self._q.get(timeout=1)
            if message is None:
                break
            yield message
        except queue.Empty:
            continue
        finally:
            elapsed_time = time.time() - start_time
            if elapsed_time >= listen_timeout or self._is_stopped():
                # publish two messages to make sure the client can receive the stop signal
                # and stop listening after the stop signal is processed
                self.publish(
                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
                )
            if elapsed_time // 10 > last_ping_time:
                self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
                last_ping_time = elapsed_time // 10
finally:
    self._graph_runtime_state = None  # Release reference once consumers finish or close the generator.
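The outer try/finally matters because this method is a generator: `finally` also fires when the consumer closes it early, so the reference is dropped on every exit path, not only after exhaustion. A self-contained illustration:

    from collections.abc import Generator

    def listen() -> Generator[int, None, None]:
        state = object()  # stands in for the held GraphRuntimeState reference
        try:
            for i in range(100):
                yield i
        finally:
            state = None  # runs on exhaustion, on exception, and on generator close
            print("reference released")

    gen = listen()
    next(gen)
    gen.close()  # triggers the finally block even though we stopped early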
def stop_listen(self):
"""
@@ -90,7 +93,6 @@ class AppQueueManager(ABC):
"""
self._clear_task_belong_cache()
self._q.put(None)
self._graph_runtime_state = None # Release reference to allow GC to reclaim memory
def _clear_task_belong_cache(self) -> None:
"""

View File

@@ -1,3 +1,4 @@
import contextvars
import logging
import threading
import uuid
@@ -120,89 +121,96 @@ class ChatAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
    files = args["files"] if args.get("files") else []
    file_extra_config = FileUploadConfigManager.convert(
        override_model_config_dict or app_model_config.to_dict()
    )
    if file_extra_config:
        file_objs = file_factory.build_from_mappings(
            mappings=files,
            tenant_id=app_model.tenant_id,
            config=file_extra_config,
            access_controller=self._file_access_controller,
        )
    else:
        file_objs = []

    # convert to app config
    app_config = ChatAppConfigManager.get_app_config(
        app_model=app_model,
        app_model_config=app_model_config,
        conversation=conversation,
        override_config_dict=override_model_config_dict,
    )

    # get tracing instance
    trace_manager = TraceQueueManager(
        app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
    )

    # init application generate entity
    application_generate_entity = ChatAppGenerateEntity(
        task_id=str(uuid.uuid4()),
        app_config=app_config,
        model_conf=ModelConfigConverter.convert(app_config),
        file_upload_config=file_extra_config,
        conversation_id=conversation.id if conversation else None,
        inputs=self._prepare_user_inputs(
            user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
        ),
        query=query,
        files=list(file_objs),
        parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
        user_id=user.id,
        invoke_from=invoke_from,
        extras=extras,
        trace_manager=trace_manager,
        stream=streaming,
    )

    # init generate records
    (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

    # init queue manager
    queue_manager = MessageBasedAppQueueManager(
        task_id=application_generate_entity.task_id,
        user_id=application_generate_entity.user_id,
        invoke_from=application_generate_entity.invoke_from,
        conversation_id=conversation.id,
        app_mode=conversation.mode,
        message_id=message.id,
    )

    context = contextvars.copy_context()

    # new thread with request context
    @copy_current_request_context
    def worker_with_context():
        return self._generate_worker(
        return context.run(
            self._generate_worker,
            flask_app=current_app._get_current_object(),  # type: ignore
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            conversation_id=conversation.id,
            message_id=message.id,
        )

    worker_thread = threading.Thread(target=worker_with_context)
    worker_thread.start()

    # return response or stream generator
    response = self._handle_response(
        application_generate_entity=application_generate_entity,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message,
        user=user,
        stream=streaming,
    )
    return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
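The `contextvars.copy_context()` / `context.run(...)` pairing snapshots the request thread's context variables and replays them inside the worker thread, so values bound on the request thread (tenant, trace ids, the file-access scope above) survive into the new thread. A minimal standalone demonstration:

    import contextvars
    import threading

    request_user: contextvars.ContextVar[str] = contextvars.ContextVar("request_user")

    def worker() -> None:
        # reads the value captured when the context was copied, not a global
        print("worker sees:", request_user.get())

    request_user.set("user-123")
    ctx = contextvars.copy_context()

    thread = threading.Thread(target=lambda: ctx.run(worker))
    thread.start()
    thread.join()  # prints "worker sees: user-123"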

View File

@@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner):
model_parameters=application_generate_entity.model_conf.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from core.workflow.system_variables import SystemVariableKey, get_system_text
from dify_graph.runtime import GraphRuntimeState
if TYPE_CHECKING:
@@ -30,10 +31,10 @@ class GraphRuntimeStateSupport:
return self._resolve_graph_runtime_state(graph_runtime_state)
def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str:
system_variables = graph_runtime_state.variable_pool.system_variables
if not system_variables or not system_variables.workflow_execution_id:
workflow_run_id = get_system_text(graph_runtime_state.variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID)
if not workflow_run_id:
raise ValueError("workflow_execution_id missing from runtime state")
return str(system_variables.workflow_execution_id)
return workflow_run_id
def _resolve_graph_runtime_state(
self,

View File

@@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Mapping, Sequence
@@ -50,20 +51,21 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import (
BuiltinNodeTypes,
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import FILE_MODEL_IDENTITY, File
from dify_graph.runtime import GraphRuntimeState
from dify_graph.system_variable import SystemVariable
from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment
from dify_graph.variables.variables import Variable
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -111,11 +113,11 @@ class WorkflowResponseConverter:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
system_variables: SystemVariable,
system_variables: Sequence[Variable],
):
self._application_generate_entity = application_generate_entity
self._user = user
self._system_variables = system_variables
self._system_variables = system_variables_to_mapping(system_variables)
self._workflow_inputs = self._prepare_workflow_inputs()
# Disable truncation for SERVICE_API calls to keep backward compatibility.
@@ -133,7 +135,7 @@ class WorkflowResponseConverter:
# ------------------------------------------------------------------
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
inputs = dict(self._application_generate_entity.inputs)
for field_name, value in self._system_variables.to_dict().items():
for field_name, value in self._system_variables.items():
# TODO(@future-refactor): store system variables separately from user inputs so we don't
# need to flatten `sys.*` entries into the input payload just for rerun/export tooling.
if field_name == SystemVariableKey.CONVERSATION_ID:
@@ -318,13 +320,23 @@ class WorkflowResponseConverter:
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
expiration_times_by_form_id: dict[str, datetime] = {}
display_in_ui_by_form_id: dict[str, bool] = {}
form_token_by_form_id: dict[str, str] = {}
if human_input_form_ids:
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
HumanInputForm.id.in_(human_input_form_ids)
)
stmt = select(
HumanInputForm.id,
HumanInputForm.expiration_time,
HumanInputForm.form_definition,
).where(HumanInputForm.id.in_(human_input_form_ids))
with Session(bind=db.engine) as session:
for form_id, expiration_time in session.execute(stmt):
for form_id, expiration_time, form_definition in session.execute(stmt):
expiration_times_by_form_id[str(form_id)] = expiration_time
try:
definition_payload = json.loads(form_definition) if form_definition else {}
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
responses: list[StreamResponse] = []
@@ -344,8 +356,8 @@ class WorkflowResponseConverter:
form_content=reason.form_content,
inputs=reason.inputs,
actions=reason.actions,
display_in_ui=reason.display_in_ui,
form_token=reason.form_token,
display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False),
form_token=form_token_by_form_id.get(reason.form_id),
resolved_default_values=reason.resolved_default_values,
expiration_time=int(expiration_time.timestamp()),
),
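The lookup above derives `display_in_ui` from the stored form definition rather than trusting the event payload, tolerating empty or malformed JSON instead of failing the whole stream response. The decode-with-fallback pattern in isolation:

    import json

    def parse_definition(raw: str | None) -> dict:
        try:
            payload = json.loads(raw) if raw else {}
        except (TypeError, json.JSONDecodeError):
            payload = {}
        return payload if isinstance(payload, dict) else {}

    assert parse_definition('{"display_in_ui": true}')["display_in_ui"] is True
    assert parse_definition("not json") == {}
    assert parse_definition(None) == {}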

View File

@@ -1,3 +1,4 @@
import contextvars
import logging
import threading
import uuid
@@ -108,83 +109,90 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
    files = args["files"] if args.get("files") else []
    file_extra_config = FileUploadConfigManager.convert(
        override_model_config_dict or app_model_config.to_dict()
    )
    if file_extra_config:
        file_objs = file_factory.build_from_mappings(
            mappings=files,
            tenant_id=app_model.tenant_id,
            config=file_extra_config,
            access_controller=self._file_access_controller,
        )
    else:
        file_objs = []

    # convert to app config
    app_config = CompletionAppConfigManager.get_app_config(
        app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
    )

    # get tracing instance
    trace_manager = TraceQueueManager(
        app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
    )

    # init application generate entity
    application_generate_entity = CompletionAppGenerateEntity(
        task_id=str(uuid.uuid4()),
        app_config=app_config,
        model_conf=ModelConfigConverter.convert(app_config),
        file_upload_config=file_extra_config,
        inputs=self._prepare_user_inputs(
            user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
        ),
        query=query,
        files=list(file_objs),
        user_id=user.id,
        stream=streaming,
        invoke_from=invoke_from,
        extras={},
        trace_manager=trace_manager,
    )

    # init generate records
    (conversation, message) = self._init_generate_records(application_generate_entity)

    # init queue manager
    queue_manager = MessageBasedAppQueueManager(
        task_id=application_generate_entity.task_id,
        user_id=application_generate_entity.user_id,
        invoke_from=application_generate_entity.invoke_from,
        conversation_id=conversation.id,
        app_mode=conversation.mode,
        message_id=message.id,
    )

    context = contextvars.copy_context()

    # new thread with request context
    @copy_current_request_context
    def worker_with_context():
        return self._generate_worker(
        return context.run(
            self._generate_worker,
            flask_app=current_app._get_current_object(),  # type: ignore
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            message_id=message.id,
        )

    worker_thread = threading.Thread(target=worker_with_context)
    worker_thread.start()

    # return response or stream generator
    response = self._handle_response(
        application_generate_entity=application_generate_entity,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message,
        user=user,
        stream=streaming,
    )
    return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
@@ -280,71 +288,76 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
    # parse files
    file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
    if file_extra_config:
        file_objs = file_factory.build_from_mappings(
            mappings=message.message_files,
            tenant_id=app_model.tenant_id,
            config=file_extra_config,
            access_controller=self._file_access_controller,
        )
    else:
        file_objs = []

    # convert to app config
    app_config = CompletionAppConfigManager.get_app_config(
        app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
    )
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
files=list(file_objs),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras={},
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
    task_id=application_generate_entity.task_id,
    user_id=application_generate_entity.user_id,
    invoke_from=application_generate_entity.invoke_from,
    conversation_id=conversation.id,
    app_mode=conversation.mode,
    message_id=message.id,
)

context = contextvars.copy_context()

# new thread with request context
@copy_current_request_context
def worker_with_context():
    return self._generate_worker(
    return context.run(
        self._generate_worker,
        flask_app=current_app._get_current_object(),  # type: ignore
        application_generate_entity=application_generate_entity,
        queue_manager=queue_manager,
        message_id=message.id,
    )

worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start()

# return response or stream generator
response = self._handle_response(
    application_generate_entity=application_generate_entity,
    queue_manager=queue_manager,
    conversation=conversation,
    message=message,
    user=user,
    stream=stream,
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

View File

@@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner):
model_parameters=application_generate_entity.model_conf.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result

View File

@@ -4,31 +4,30 @@ import abc
from collections.abc import Mapping
from typing import Any, Protocol
from sqlalchemy.orm import Session
from dify_graph.enums import NodeType
class DraftVariableSaver(Protocol):
@abc.abstractmethod
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
pass
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
"""Persist node draft variables for a completed execution."""
raise NotImplementedError
class DraftVariableSaverFactory(Protocol):
@abc.abstractmethod
def __call__(
self,
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> DraftVariableSaver:
pass
"""Build a saver bound to a concrete node execution."""
raise NotImplementedError
class NoopDraftVariableSaver(DraftVariableSaver):
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
pass
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
return None
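Since `DraftVariableSaver` is a `Protocol`, any class with a matching `save` conforms structurally; raising `NotImplementedError` in the protocol body documents the contract without supplying behavior. A minimal conforming implementation (a sketch, assuming the protocol above), e.g. for tests:

    from collections.abc import Mapping
    from typing import Any

    class InMemorySaver:
        """Collects draft variables in memory; handy for tests."""

        def __init__(self) -> None:
            self.saved: list[tuple[dict, dict]] = []

        def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
            self.saved.append((dict(process_data or {}), dict(outputs or {})))

    saver: "DraftVariableSaver" = InMemorySaver()  # type-checks structurally
    saver.save({"step": 1}, {"value": "done"})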

View File

@@ -28,6 +28,7 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file_reference import resolve_file_record_id
from extensions.ext_database import db
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
@@ -227,7 +228,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
transfer_method=file.transfer_method,
belongs_to=MessageFileBelongsTo.USER,
url=file.remote_url,
upload_file_id=file.related_id,
upload_file_id=resolve_file_record_id(file.reference),
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)

View File

@@ -18,6 +18,7 @@ import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
@@ -34,11 +35,12 @@ from core.datasource.entities.datasource_entities import (
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.factory import (
DifyCoreRepositoryFactory,
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository,
)
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from libs.flask_utils import preserve_flask_contexts

View File

@@ -12,16 +12,16 @@ from core.app.entities.app_invoke_entities import (
build_dify_run_context,
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.graph_init_params import GraphInitParams
from dify_graph.enums import WorkflowType
from dify_graph.graph import Graph
from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import VariableLoader
from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from extensions.ext_database import db
@@ -112,7 +112,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
system_inputs = build_system_variables(
files=files,
user_id=user_id,
app_id=app_config.app_id,
@@ -142,19 +142,25 @@ class PipelineRunner(WorkflowBasedAppRunner):
)
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
variable_pool = VariablePool()
add_variables_to_pool(
variable_pool,
build_bootstrap_variables(
system_variables=system_inputs,
environment_variables=workflow.environment_variables,
rag_pipeline_variables=rag_pipeline_variables,
),
)
root_node_id = self.application_generate_entity.start_node_id or get_default_root_node_id(
workflow.graph_dict
)
add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_rag_pipeline_graph(
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
start_node_id=root_node_id,
workflow=workflow,
user_from=user_from,
invoke_from=invoke_from,
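`get_default_root_node_id` is the fallback when no explicit `start_node_id` is provided. Its real selection logic is not shown in this diff; a hedged sketch of one plausible strategy (pick the first node with no incoming edge):

    def default_root_node_id(graph: dict) -> str:
        # hypothetical: the real helper may apply different tie-breaking rules
        nodes = [n["id"] for n in graph.get("nodes", [])]
        targets = {e["target"] for e in graph.get("edges", [])}
        for node_id in nodes:
            if node_id not in targets:
                return node_id
        return nodes[0]

    graph = {
        "nodes": [{"id": "start"}, {"id": "llm"}],
        "edges": [{"source": "start", "target": "llm"}],
    }
    assert default_root_node_id(graph) == "start"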

View File

@@ -17,6 +17,7 @@ from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
@@ -30,11 +31,9 @@ from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.runtime import GraphRuntimeState
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
@@ -129,107 +128,109 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
    files: Sequence[Mapping[str, Any]] = args.get("files") or []

    # parse files
    # TODO(QuantumGhost): Move file parsing logic to the API controller layer
    # for better separation of concerns.
    #
    # For implementation reference, see the `_parse_file` function and
    # `DraftWorkflowNodeRunApi` class which handle this properly.
    file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
    system_files = file_factory.build_from_mappings(
        mappings=files,
        tenant_id=app_model.tenant_id,
        config=file_extra_config,
        strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
        access_controller=self._file_access_controller,
    )

    # convert to app config
    app_config = WorkflowAppConfigManager.get_app_config(
        app_model=app_model,
        workflow=workflow,
    )

    # get tracing instance
    trace_manager = TraceQueueManager(
        app_id=app_model.id,
        user_id=user.id if isinstance(user, Account) else user.session_id,
    )

    inputs: Mapping[str, Any] = args["inputs"]
    extras = {
        **extract_external_trace_id_from_args(args),
    }
    workflow_run_id = str(workflow_run_id or uuid.uuid4())

    # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
    # trigger shouldn't prepare user inputs
    if self._should_prepare_user_inputs(args):
        inputs = self._prepare_user_inputs(
            user_inputs=inputs,
            variables=app_config.variables,
            tenant_id=app_model.tenant_id,
            strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
        )

    # init application generate entity
    application_generate_entity = WorkflowAppGenerateEntity(
        task_id=str(uuid.uuid4()),
        app_config=app_config,
        file_upload_config=file_extra_config,
        inputs=inputs,
        files=list(system_files),
        user_id=user.id,
        stream=streaming,
        invoke_from=invoke_from,
        call_depth=call_depth,
        trace_manager=trace_manager,
        workflow_execution_id=workflow_run_id,
        extras=extras,
    )

    contexts.plugin_tool_providers.set({})
    contexts.plugin_tool_providers_lock.set(threading.Lock())

    # Create repositories
    #
    # Create session factory
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    # Create workflow execution (aka workflow run) repository
    if triggered_from is not None:
        # Use explicitly provided triggered_from (for async triggers)
        workflow_triggered_from = triggered_from
    elif invoke_from == InvokeFrom.DEBUGGER:
        workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
    else:
        workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
    workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=application_generate_entity.app_config.app_id,
        triggered_from=workflow_triggered_from,
    )
    # Create workflow node execution repository
    workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
        session_factory=session_factory,
        user=user,
        app_id=application_generate_entity.app_config.app_id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    return self._generate(
        app_model=app_model,
        workflow=workflow,
        user=user,
        application_generate_entity=application_generate_entity,
        invoke_from=invoke_from,
        workflow_execution_repository=workflow_execution_repository,
        workflow_node_execution_repository=workflow_node_execution_repository,
        streaming=streaming,
        root_node_id=root_node_id,
        graph_engine_layers=graph_engine_layers,
        pause_state_config=pause_state_config,
    )
def resume(
self,
@@ -292,62 +293,67 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
"""
with self._bind_file_access_scope(
    tenant_id=application_generate_entity.app_config.tenant_id,
    user=user,
    invoke_from=invoke_from,
):
    graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)

    # init queue manager
    queue_manager = WorkflowAppQueueManager(
        task_id=application_generate_entity.task_id,
        user_id=application_generate_entity.user_id,
        invoke_from=application_generate_entity.invoke_from,
        app_mode=app_model.mode,
    )

    if pause_state_config is not None:
        graph_layers.append(
            PauseStatePersistenceLayer(
                session_factory=pause_state_config.session_factory,
                generate_entity=application_generate_entity,
                state_owner_user_id=pause_state_config.state_owner_user_id,
            )
        )

    # release database connection, because the following new thread operations may take a long time
    db.session.close()

    # new thread with request context and contextvars
    context = contextvars.copy_context()

    worker_thread = threading.Thread(
        target=self._generate_worker,
        kwargs={
            "flask_app": current_app._get_current_object(),  # type: ignore
            "application_generate_entity": application_generate_entity,
            "queue_manager": queue_manager,
            "context": context,
            "variable_loader": variable_loader,
            "root_node_id": root_node_id,
            "workflow_execution_repository": workflow_execution_repository,
            "workflow_node_execution_repository": workflow_node_execution_repository,
            "graph_engine_layers": tuple(graph_layers),
            "graph_runtime_state": graph_runtime_state,
        },
    )

    draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
    worker_thread.start()

    # return response or stream generator
    response = self._handle_response(
        application_generate_entity=application_generate_entity,
        workflow=workflow,
        queue_manager=queue_manager,
        user=user,
        draft_var_saver_factory=draft_var_saver_factory,
        stream=streaming,
    )
    return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self,

View File

@@ -8,14 +8,15 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import get_default_root_node_id
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.enums import WorkflowType
from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import VariableLoader
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
@@ -96,7 +97,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
inputs = self.application_generate_entity.inputs
# Create a variable pool.
system_inputs = SystemVariable(
system_inputs = build_system_variables(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
@@ -104,12 +105,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
conversation_variables=[],
variable_pool = VariablePool()
add_variables_to_pool(
variable_pool,
build_bootstrap_variables(
system_variables=system_inputs,
environment_variables=self._workflow.environment_variables,
),
)
root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict)
add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph = self._init_graph(
@@ -120,7 +125,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
root_node_id=self._root_node_id,
root_node_id=root_node_id,
)
# RUN WORKFLOW

View File

@@ -10,6 +10,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
@@ -55,11 +56,10 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.system_variables import build_system_variables
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.runtime import GraphRuntimeState
from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
from models import Account
from models.enums import CreatorUserRole
@@ -104,7 +104,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
self._workflow = workflow
self._workflow_system_variables = SystemVariable(
self._workflow_system_variables = build_system_variables(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
@@ -728,13 +728,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
return response
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
with Session(db.engine) as session, session.begin():
saver = self._draft_var_saver_factory(
session=session,
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)
saver = self._draft_var_saver_factory(
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)
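The call site no longer opens a session, so transaction ownership must now live behind the factory. The real types are not shown in this hunk; a hedged Protocol sketch of the contract implied by the new call site:

from typing import Any, Mapping, Protocol

class _DraftVariableSaver(Protocol):
    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: ...

class _DraftVariableSaverFactory(Protocol):
    # Keyword names mirror the call above; the actual signature may differ.
    def __call__(
        self,
        *,
        app_id: str,
        node_id: str,
        node_type: Any,
        node_execution_id: str,
        enclosing_node_id: str | None,
    ) -> _DraftVariableSaver: ...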

View File

@@ -34,7 +34,16 @@ from core.app.entities.queue_entities import (
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.system_variables import (
build_bootstrap_variables,
default_system_variables,
get_node_creation_preload_selectors,
inject_default_system_variable_mappings,
preload_node_creation_variables,
)
from core.workflow.variable_pool_initializer import add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired
@@ -68,7 +77,6 @@ from dify_graph.graph_events import (
)
from dify_graph.graph_events.graph import GraphRunAbortedEvent
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from models.workflow import Workflow
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
@@ -173,14 +181,15 @@ class WorkflowBasedAppRunner:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
variable_pool = VariablePool()
add_variables_to_pool(
variable_pool,
build_bootstrap_variables(
system_variables=default_system_variables(),
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
@@ -272,6 +281,8 @@ class WorkflowBasedAppRunner:
graph_config["edges"] = edge_configs
typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs]
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
workflow_id=workflow.id,
@@ -291,26 +302,15 @@ class WorkflowBasedAppRunner:
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(
graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True
)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
for node in typed_node_configs:
if node["id"] == node_id:
target_node_config = node
break
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
# Get node class
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)
@@ -319,12 +319,31 @@ class WorkflowBasedAppRunner:
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool
preload_node_creation_variables(
variable_loader=self._variable_loader,
variable_pool=variable_pool,
selectors=[
selector
for node_config in typed_node_configs
for selector in get_node_creation_preload_selectors(
node_type=node_config["data"].type,
node_data=node_config["data"],
)
],
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=target_node_config
)
except NotImplementedError:
variable_mapping = {}
variable_mapping = inject_default_system_variable_mappings(
node_id=target_node_config["id"],
node_type=node_type,
node_data=target_node_config["data"],
variable_mapping=variable_mapping,
)
load_into_variable_pool(
variable_loader=self._variable_loader,
@@ -340,6 +359,14 @@ class WorkflowBasedAppRunner:
tenant_id=workflow.tenant_id,
)
# init graph after constructor-time context has been loaded
graph = Graph.init(
graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True
)
if not graph:
raise ValueError("graph not found in workflow")
return graph, variable_pool
@staticmethod
@@ -408,7 +435,11 @@ class WorkflowBasedAppRunner:
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
inputs=inputs,
outputs=node_run_result.outputs,
)
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeRetryEvent(
@@ -448,7 +479,11 @@ class WorkflowBasedAppRunner:
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
inputs=inputs,
outputs=node_run_result.outputs,
)
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeSucceededEvent(
@@ -466,6 +501,11 @@ class WorkflowBasedAppRunner:
)
)
elif isinstance(event, NodeRunFailedEvent):
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
inputs=event.node_run_result.inputs,
outputs=event.node_run_result.outputs,
)
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
@@ -475,7 +515,7 @@ class WorkflowBasedAppRunner:
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
outputs=outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
@@ -483,6 +523,11 @@ class WorkflowBasedAppRunner:
)
)
elif isinstance(event, NodeRunExceptionEvent):
outputs = project_node_outputs_for_workflow_run(
node_type=event.node_type,
inputs=event.node_run_result.inputs,
outputs=event.node_run_result.outputs,
)
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
@@ -492,7 +537,7 @@ class WorkflowBasedAppRunner:
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
outputs=outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,

View File
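The same projection call now guards every publish site in the runner above (retry, succeeded, failed, exception). Its internals are not part of this diff; a hypothetical wrapper showing the shared call shape the four sites repeat:

from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run

def _projected_outputs(event):
    # Hypothetical helper, not in this commit; all four publish sites
    # pass the same (node_type, inputs, outputs) triple.
    result = event.node_run_result
    return project_node_outputs_for_workflow_run(
        node_type=event.node_type,
        inputs=result.inputs,
        outputs=result.outputs,
    )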

@@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.file import File, FileUploadConfig
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
@@ -15,6 +14,9 @@ if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
DIFY_RUN_CONTEXT_KEY = "_dify"
class UserFrom(StrEnum):
ACCOUNT = "account"
END_USER = "end-user"

View File

@@ -0,0 +1,11 @@
from .controller import DatabaseFileAccessController
from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope
__all__ = [
"DatabaseFileAccessController",
"FileAccessControllerProtocol",
"FileAccessScope",
"bind_file_access_scope",
"get_current_file_access_scope",
]

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from collections.abc import Callable
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, get_current_file_access_scope
class DatabaseFileAccessController(FileAccessControllerProtocol):
"""Workflow-layer authorization helper for database-backed file lookups.
Tenant scoping remains mandatory. When the current execution belongs to an
end user, the lookup is additionally constrained to that end user's file
ownership markers.
"""
_scope_getter: Callable[[], FileAccessScope | None]
def __init__(
self,
*,
scope_getter: Callable[[], FileAccessScope | None] = get_current_file_access_scope,
) -> None:
self._scope_getter = scope_getter
def current_scope(self) -> FileAccessScope | None:
return self._scope_getter()
def apply_upload_file_filters(
self,
stmt: Select[tuple[UploadFile]],
*,
scope: FileAccessScope | None = None,
) -> Select[tuple[UploadFile]]:
resolved_scope = scope or self.current_scope()
if resolved_scope is None:
return stmt
scoped_stmt = stmt.where(UploadFile.tenant_id == resolved_scope.tenant_id)
if not resolved_scope.requires_user_ownership:
return scoped_stmt
return scoped_stmt.where(
UploadFile.created_by_role == CreatorUserRole.END_USER,
UploadFile.created_by == resolved_scope.user_id,
)
def apply_tool_file_filters(
self,
stmt: Select[tuple[ToolFile]],
*,
scope: FileAccessScope | None = None,
) -> Select[tuple[ToolFile]]:
resolved_scope = scope or self.current_scope()
if resolved_scope is None:
return stmt
scoped_stmt = stmt.where(ToolFile.tenant_id == resolved_scope.tenant_id)
if not resolved_scope.requires_user_ownership:
return scoped_stmt
return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)
def get_upload_file(
self,
*,
session: Session,
file_id: str,
scope: FileAccessScope | None = None,
) -> UploadFile | None:
resolved_scope = scope or self.current_scope()
if resolved_scope is None:
return session.get(UploadFile, file_id)
stmt = self.apply_upload_file_filters(
select(UploadFile).where(UploadFile.id == file_id),
scope=resolved_scope,
)
return session.scalar(stmt)
def get_tool_file(
self,
*,
session: Session,
file_id: str,
scope: FileAccessScope | None = None,
) -> ToolFile | None:
resolved_scope = scope or self.current_scope()
if resolved_scope is None:
return session.get(ToolFile, file_id)
stmt = self.apply_tool_file_filters(
select(ToolFile).where(ToolFile.id == file_id),
scope=resolved_scope,
)
return session.scalar(stmt)
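A short usage sketch, assuming the session factory used elsewhere in this commit; the file id is a placeholder:

from sqlalchemy import select

from core.app.file_access import DatabaseFileAccessController
from core.db.session_factory import session_factory
from models import UploadFile

controller = DatabaseFileAccessController()
with session_factory.create_session() as session:
    # Access-control predicates are appended to the caller's own filters.
    stmt = controller.apply_upload_file_filters(
        select(UploadFile).where(UploadFile.id == "upload-file-id")
    )
    upload_file = session.scalar(stmt)  # None outside the scope's tenant/user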

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
from typing import Protocol
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from models import ToolFile, UploadFile
from .scope import FileAccessScope
class FileAccessControllerProtocol(Protocol):
"""Contract for applying access rules to file lookups.
Implementations translate an optional execution scope into query constraints
and authorized record retrieval. The contract is intentionally limited to
ownership and tenancy rules for workflow-layer file access.
"""
def current_scope(self) -> FileAccessScope | None:
"""Return the scope active for the current execution, if one exists.
Callers use this to decide whether embedded file metadata may be trusted
or whether a fresh authorized lookup is required.
"""
...
def apply_upload_file_filters(
self,
stmt: Select[tuple[UploadFile]],
*,
scope: FileAccessScope | None = None,
) -> Select[tuple[UploadFile]]:
"""Return an upload-file query constrained by the supplied access scope.
The returned statement must preserve the caller's existing predicates and
append only access-control conditions.
"""
...
def apply_tool_file_filters(
self,
stmt: Select[tuple[ToolFile]],
*,
scope: FileAccessScope | None = None,
) -> Select[tuple[ToolFile]]:
"""Return a tool-file query constrained by the supplied access scope.
The returned statement must preserve the caller's existing predicates and
append only access-control conditions.
"""
...
def get_upload_file(
self,
*,
session: Session,
file_id: str,
scope: FileAccessScope | None = None,
) -> UploadFile | None:
"""Load one authorized upload-file record for the given identifier.
Returns ``None`` when the file does not exist or when the scope does not
permit access to that record.
"""
...
def get_tool_file(
self,
*,
session: Session,
file_id: str,
scope: FileAccessScope | None = None,
) -> ToolFile | None:
"""Load one authorized tool-file record for the given identifier.
Returns ``None`` when the file does not exist or when the scope does not
permit access to that record.
"""
...
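Because the contract is a Protocol, trusted contexts can satisfy it structurally with a pass-through double. A hypothetical permissive implementation, useful in tests and not part of this commit:

from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from models import ToolFile, UploadFile

class AllowAllFileAccessController:
    # Structural match for FileAccessControllerProtocol with no scope:
    # filters pass through and lookups fall back to primary-key gets.
    def current_scope(self) -> None:
        return None

    def apply_upload_file_filters(self, stmt: Select, *, scope=None) -> Select:
        return stmt

    def apply_tool_file_filters(self, stmt: Select, *, scope=None) -> Select:
        return stmt

    def get_upload_file(self, *, session: Session, file_id: str, scope=None) -> UploadFile | None:
        return session.get(UploadFile, file_id)

    def get_tool_file(self, *, session: Session, file_id: str, scope=None) -> ToolFile | None:
        return session.get(ToolFile, file_id)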

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
_current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar(
"current_file_access_scope",
default=None,
)
@dataclass(frozen=True, slots=True)
class FileAccessScope:
"""Request-scoped ownership context used by workflow-layer file lookups."""
tenant_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
@property
def requires_user_ownership(self) -> bool:
return self.user_from == UserFrom.END_USER
def get_current_file_access_scope() -> FileAccessScope | None:
return _current_file_access_scope.get()
@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
token = _current_file_access_scope.set(scope)
try:
yield
finally:
_current_file_access_scope.reset(token)
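A usage sketch; the literal values are placeholders and the `InvokeFrom` member is an assumption drawn from this diff's imports:

scope = FileAccessScope(
    tenant_id="tenant-id",        # placeholders for request-derived values
    user_id="end-user-id",
    user_from=UserFrom.END_USER,
    invoke_from=InvokeFrom.SERVICE_API,  # assumed member of InvokeFrom
)
with bind_file_access_scope(scope):
    # File lookups in this block see tenant + end-user ownership filters;
    # outside it, get_current_file_access_scope() returns None again.
    ...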

View File

@@ -1,12 +1,19 @@
"""
Persist conversation-scoped variable updates emitted by the graph engine.
The graph package emits generic variable update events and stays unaware of
conversation identity or storage concerns. This layer lives in the application
core, listens to those generic events, and persists only the `conversation.*`
scope updates that matter to chat applications.
"""
import logging
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.conversation_variable_updater import ConversationVariableUpdater
from dify_graph.enums import BuiltinNodeTypes
from core.workflow.system_variables import SystemVariableKey, get_system_text
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
from dify_graph.variables import VariableBase
from dify_graph.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent
from services.conversation_variable_updater import ConversationVariableUpdater
logger = logging.getLogger(__name__)
@@ -20,41 +27,22 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
pass
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunSucceededEvent):
return
if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER:
return
if self.graph_runtime_state is None:
if not isinstance(event, NodeRunVariableUpdatedEvent):
return
updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
if not updated_variables:
selector = event.variable.selector
if len(selector) < 2:
logger.warning("Conversation variable selector invalid. selector=%s", selector)
return
conversation_id = self.graph_runtime_state.system_variable.conversation_id
conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID)
if conversation_id is None:
return
updated_any = False
for item in updated_variables:
selector = item.selector
if len(selector) < 2:
logger.warning("Conversation variable selector invalid. selector=%s", selector)
continue
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, VariableBase):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,
)
continue
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
updated_any = True
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
return
if updated_any:
self._conversation_variable_updater.flush()
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable)
def on_graph_end(self, error: Exception | None) -> None:
pass
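The behavioral shift in this layer: instead of scanning a succeeded assigner node's process data and flushing a batch, it now persists one update per NodeRunVariableUpdatedEvent whose selector starts with the conversation prefix. A hypothetical recording updater makes the new contract visible:

class RecordingConversationVariableUpdater:
    # Hypothetical test double for services.conversation_variable_updater;
    # the layer calls update() once per matching event, with no flush step.
    def __init__(self) -> None:
        self.calls: list[tuple[str, object]] = []

    def update(self, *, conversation_id: str, variable: object) -> None:
        self.calls.append((conversation_id, variable))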

View File

@@ -6,6 +6,7 @@ from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.system_variables import SystemVariableKey, get_system_text
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events.base import GraphEngineEvent
from dify_graph.graph_events.graph import GraphRunPausedEvent
@@ -119,7 +120,10 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
generate_entity=entity_wrapper,
)
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
workflow_run_id = get_system_text(
self.graph_runtime_state.variable_pool,
SystemVariableKey.WORKFLOW_EXECUTION_ID,
)
assert workflow_run_id is not None
repo = self._get_repo()
repo.create_workflow_pause(

View File

@@ -5,6 +5,7 @@ from typing import Any, ClassVar
from pydantic import TypeAdapter
from core.db.session_factory import session_factory
from core.workflow.system_variables import SystemVariableKey, get_system_text
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events.base import GraphEngineEvent
from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
@@ -59,7 +60,10 @@ class TriggerPostLayer(GraphEngineLayer):
outputs = self.graph_runtime_state.outputs
# Basically, workflow_execution_id is the same as workflow_run_id
workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id
workflow_run_id = get_system_text(
self.graph_runtime_state.variable_pool,
SystemVariableKey.WORKFLOW_EXECUTION_ID,
)
assert workflow_run_id, "Workflow run id is not set"
total_tokens = self.graph_runtime_state.total_tokens

View File

@@ -2,23 +2,34 @@ from __future__ import annotations
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.nodes.llm.entities import ModelConfig
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider
class DifyCredentialsProvider:
tenant_id: str
provider_manager: ProviderManager
def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
self.tenant_id = tenant_id
self.provider_manager = provider_manager or ProviderManager()
def __init__(
self,
*,
run_context: DifyRunContext,
provider_manager: ProviderManager | None = None,
) -> None:
self.tenant_id = run_context.tenant_id
if provider_manager is None:
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
@@ -42,9 +53,21 @@ class DifyModelFactory:
tenant_id: str
model_manager: ModelManager
def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
self.tenant_id = tenant_id
self.model_manager = model_manager or ModelManager()
def __init__(
self,
*,
run_context: DifyRunContext,
model_manager: ModelManager | None = None,
) -> None:
self.tenant_id = run_context.tenant_id
if model_manager is None:
model_manager = ModelManager(
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
)
self.model_manager = model_manager
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
return self.model_manager.get_model_instance(
@@ -55,18 +78,42 @@ class DifyModelFactory:
)
def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
return (
DifyCredentialsProvider(tenant_id=tenant_id),
DifyModelFactory(tenant_id=tenant_id),
def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, DifyModelFactory]:
"""Create LLM access adapters that share the same tenant-bound manager graph."""
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
DifyModelFactory(run_context=run_context, model_manager=model_manager),
)
def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Split node-level completion params into provider parameters and stop sequences.
Workflow LLM-compatible nodes still consume runtime invocation settings from
``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the
``ModelInstance`` view and the returned config entity aligned here so callers
do not need to duplicate normalization logic.
"""
normalized_parameters = dict(completion_params)
stop = normalized_parameters.pop("stop", [])
if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop):
stop = []
return normalized_parameters, stop
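A worked example of the normalization, following the function body above:

params, stop = _normalize_completion_params({"temperature": 0.7, "stop": ["\n\n"]})
assert params == {"temperature": 0.7} and stop == ["\n\n"]

# Malformed stop values are dropped rather than propagated downstream.
params, stop = _normalize_completion_params({"stop": "not-a-list"})
assert params == {} and stop == []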
def fetch_model_config(
*,
node_data_model: ModelConfig,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
model_factory: DifyModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
@@ -80,22 +127,18 @@ def fetch_model_config(
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
raise ModelNotExistError(f"Model {node_data_model.name} does not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
if model_schema is None:
raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.")
parameters, stop = _normalize_completion_params(node_data_model.completion_params)
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.parameters = parameters
model_instance.stop = tuple(stop)
return model_instance, ModelConfigWithCredentialsEntity(
@@ -103,8 +146,8 @@ def fetch_model_config(
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=completion_params,
parameters=parameters,
stop=stop,
provider_model_bundle=provider_model_bundle,
)

View File

@@ -1,33 +1,42 @@
from __future__ import annotations
import base64
import hashlib
import hmac
import os
import time
import urllib.parse
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from dify_graph.file.enums import FileTransferMethod
from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
from dify_graph.file.runtime import set_workflow_file_runtime
from extensions.ext_storage import storage
if TYPE_CHECKING:
from dify_graph.file.models import File
class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
"""Production runtime wiring for ``dify_graph.file``."""
"""Production runtime wiring for ``dify_graph.file``.
@property
def files_url(self) -> str:
return dify_config.FILES_URL
Opaque file references are resolved back to canonical database records before
URLs are signed or storage keys are used. When a request-scoped file access
scope is present, those lookups additionally enforce tenant and end-user
ownership filters.
"""
@property
def internal_files_url(self) -> str | None:
return dify_config.INTERNAL_FILES_URL
_file_access_controller: FileAccessControllerProtocol
@property
def secret_key(self) -> str:
return dify_config.SECRET_KEY
@property
def files_access_timeout(self) -> int:
return dify_config.FILES_ACCESS_TIMEOUT
def __init__(self, *, file_access_controller: FileAccessControllerProtocol) -> None:
self._file_access_controller = file_access_controller
@property
def multimodal_send_format(self) -> str:
@@ -39,9 +48,137 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)
def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
def load_file_bytes(self, *, file: File) -> bytes:
storage_key = self._resolve_storage_key(file=file)
data = storage.load(storage_key, stream=False)
if not isinstance(data, bytes):
raise ValueError(f"file {storage_key} is not a bytes object")
return data
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
if file.transfer_method == FileTransferMethod.REMOTE_URL:
return file.remote_url
parsed_reference = parse_file_reference(file.reference)
if parsed_reference is None:
raise ValueError("Missing file reference")
if file.transfer_method == FileTransferMethod.LOCAL_FILE:
return self.resolve_upload_file_url(
upload_file_id=parsed_reference.record_id,
for_external=for_external,
)
if file.transfer_method == FileTransferMethod.DATASOURCE_FILE:
if file.extension is None:
raise ValueError("Missing file extension")
self._assert_upload_file_access(upload_file_id=parsed_reference.record_id)
return sign_tool_file(
tool_file_id=parsed_reference.record_id,
extension=file.extension,
for_external=for_external,
)
if file.transfer_method == FileTransferMethod.TOOL_FILE:
if file.extension is None:
raise ValueError("Missing file extension")
return self.resolve_tool_file_url(
tool_file_id=parsed_reference.record_id,
extension=file.extension,
for_external=for_external,
)
return None
def resolve_upload_file_url(
self,
*,
upload_file_id: str,
as_attachment: bool = False,
for_external: bool = True,
) -> str:
self._assert_upload_file_access(upload_file_id=upload_file_id)
base_url = self._base_url(for_external=for_external)
url = f"{base_url}/files/{upload_file_id}/file-preview"
query = self._sign_query(payload=f"file-preview|{upload_file_id}")
if as_attachment:
query["as_attachment"] = "true"
return f"{url}?{urllib.parse.urlencode(query)}"
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
self._assert_tool_file_access(tool_file_id=tool_file_id)
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
def verify_preview_signature(
self,
*,
preview_kind: Literal["image", "file"],
file_id: str,
timestamp: str,
nonce: str,
sign: str,
) -> bool:
payload = f"{preview_kind}-preview|{file_id}|{timestamp}|{nonce}"
recalculated = hmac.new(self._secret_key(), payload.encode(), hashlib.sha256).digest()
if sign != base64.urlsafe_b64encode(recalculated).decode():
return False
return int(time.time()) - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def _base_url(*, for_external: bool) -> str:
if for_external:
return dify_config.FILES_URL
return dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
@staticmethod
def _secret_key() -> bytes:
return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
def _sign_query(self, *, payload: str) -> dict[str, str]:
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
sign = hmac.new(self._secret_key(), f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest()
return {
"timestamp": timestamp,
"nonce": nonce,
"sign": base64.urlsafe_b64encode(sign).decode(),
}
def _resolve_storage_key(self, *, file: File) -> str:
parsed_reference = parse_file_reference(file.reference)
if parsed_reference is None:
raise ValueError("Missing file reference")
record_id = parsed_reference.record_id
with session_factory.create_session() as session:
if file.transfer_method in {
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
FileTransferMethod.DATASOURCE_FILE,
}:
upload_file = self._file_access_controller.get_upload_file(session=session, file_id=record_id)
if upload_file is None:
raise ValueError(f"Upload file {record_id} not found")
return upload_file.key
tool_file = self._file_access_controller.get_tool_file(session=session, file_id=record_id)
if tool_file is None:
raise ValueError(f"Tool file {record_id} not found")
return tool_file.file_key
def _assert_upload_file_access(self, *, upload_file_id: str) -> None:
if self._file_access_controller.current_scope() is None:
return
with session_factory.create_session() as session:
upload_file = self._file_access_controller.get_upload_file(session=session, file_id=upload_file_id)
if upload_file is None:
raise ValueError(f"Upload file {upload_file_id} not found")
def _assert_tool_file_access(self, *, tool_file_id: str) -> None:
if self._file_access_controller.current_scope() is None:
return
with session_factory.create_session() as session:
tool_file = self._file_access_controller.get_tool_file(session=session, file_id=tool_file_id)
if tool_file is None:
raise ValueError(f"Tool file {tool_file_id} not found")
def bind_dify_workflow_file_runtime() -> None:
set_workflow_file_runtime(DifyWorkflowFileRuntime())
set_workflow_file_runtime(DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController()))
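End to end, the wiring above means a bound scope changes what the runtime will resolve. A hedged sketch, where `some_file` is a placeholder `File` whose reference points at an upload-file record and `scope` is built as in the scope module earlier in this commit (the runtime class's import path is not shown here):

from core.app.file_access import DatabaseFileAccessController, bind_file_access_scope

runtime = DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())
with bind_file_access_scope(scope):
    url = runtime.resolve_file_url(file=some_file, for_external=True)
    # Raises ValueError when the scoped lookup cannot see the backing
    # record; with no bound scope, resolution skips the ownership checks.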

View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
@@ -75,7 +76,7 @@ class LLMQuotaLayer(GraphEngineLayer):
return
try:
dify_ctx = node.require_dify_context()
dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
deduct_llm_quota(
tenant_id=dify_ctx.tenant_id,
model_instance=model_instance,
@@ -114,11 +115,11 @@ class LLMQuotaLayer(GraphEngineLayer):
try:
match node.node_type:
case BuiltinNodeTypes.LLM:
return cast("LLMNode", node).model_instance
model_instance = cast("LLMNode", node).model_instance
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
model_instance = cast("ParameterExtractorNode", node).model_instance
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
model_instance = cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
@@ -127,3 +128,12 @@ class LLMQuotaLayer(GraphEngineLayer):
node.id,
)
return None
if isinstance(model_instance, ModelInstance):
return model_instance
raw_model_instance = getattr(model_instance, "_model_instance", None)
if isinstance(raw_model_instance, ModelInstance):
return raw_model_instance
return None
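The new tail also tolerates nodes that expose a wrapper instead of a raw ModelInstance: anything carrying the real instance under a `_model_instance` attribute is unwrapped. A hypothetical wrapper the fallback would accept:

class _WrappedModelInstance:
    # Hypothetical node-level facade; the layer's getattr fallback
    # recovers the underlying ModelInstance for quota deduction.
    def __init__(self, inner: ModelInstance) -> None:
        self._model_instance = inner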

View File

@@ -17,10 +17,12 @@ from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.system_variables import SystemVariableKey
from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run
from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution
from dify_graph.enums import (
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
@@ -43,8 +45,6 @@ from dify_graph.graph_events import (
NodeRunSucceededEvent,
)
from dify_graph.node_events import NodeRunResult
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from libs.datetime_utils import naive_utc_now
@@ -372,10 +372,15 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
domain_execution.error = error
if update_outputs:
projected_outputs = project_node_outputs_for_workflow_run(
node_type=domain_execution.node_type,
inputs=node_result.inputs,
outputs=node_result.outputs,
)
domain_execution.update_from_mapping(
inputs=node_result.inputs,
process_data=node_result.process_data,
outputs=node_result.outputs,
outputs=projected_outputs,
metadata=node_result.metadata,
)

View File

@@ -25,12 +25,10 @@ class AudioTrunk:
self.status = status
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
)
return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice)
def _process_future(
@@ -62,7 +60,7 @@ class AppGeneratorTTSPublisher:
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts")
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.TTS
)
@@ -89,7 +87,7 @@ class AppGeneratorTTSPublisher:
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
_invoice_tts, self.msg_text, self.model_instance, self.voice
)
future_queue.put(futures_result)
break
@@ -117,9 +115,7 @@ class AppGeneratorTTSPublisher:
if len(sentence_arr) >= min(self.max_sentence, 7):
self.max_sentence += 1
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
)
futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice)
future_queue.put(futures_result)
if isinstance(text_tmp, str):
self.msg_text = text_tmp
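The tenant-scoped constructor pattern introduced here, sketched in isolation; `tenant_id` and the voice name are placeholders, and `invoke_tts` follows the session-free signature shown later in this commit:

from core.model_manager import ModelManager
from dify_graph.model_runtime.entities.model_entities import ModelType

manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id="responding_tts")
tts_instance = manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
for chunk in tts_instance.invoke_tts(content_text="Hello.", voice="default"):
    ...  # stream audio bytes to the caller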

View File

@@ -6,6 +6,7 @@ from typing import Any, cast
from sqlalchemy import select
import contexts
from core.app.file_access import DatabaseFileAccessController
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.datasource_entities import (
@@ -24,10 +25,11 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.workflow.file_reference import build_file_reference
from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.file import File
from dify_graph.file import File, get_file_type_by_mime_type
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from factories import file_factory
@@ -36,6 +38,7 @@ from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class DatasourceManager:
@@ -279,11 +282,15 @@ class DatasourceManager:
if datasource_file is not None:
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(mime_type),
"type": get_file_type_by_mime_type(mime_type),
"transfer_method": FileTransferMethod.TOOL_FILE,
"url": url,
}
file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
file_out = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
access_controller=_file_access_controller,
)
elif mtype == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
@@ -351,11 +358,10 @@ class DatasourceManager:
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
reference=build_file_reference(record_id=str(upload_file.id)),
size=upload_file.size,
storage_key=upload_file.key,
url=upload_file.source_url,

View File

@@ -4,6 +4,7 @@ from mimetypes import guess_extension, guess_type
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.file_reference import parse_file_reference
from dify_graph.file import File, FileTransferMethod, FileType
from models.tools import ToolFile
@@ -103,8 +104,14 @@ class DatasourceFileMessageTransformer:
file: File | None = meta.get("file")
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
reference = getattr(file, "reference", None) or getattr(file, "related_id", None)
parsed_reference = parse_file_reference(reference) if isinstance(reference, str) else None
if parsed_reference is None:
raise ValueError("datasource file is missing reference")
url = cls.get_datasource_file_url(
datasource_file_id=parsed_reference.record_id,
extension=file.extension,
)
if file.type == FileType.IMAGE:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,

View File

@@ -1,10 +1,5 @@
from enum import StrEnum, auto
"""Compatibility wrapper for the runtime embedding input enum."""
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType
class EmbeddingInputType(StrEnum):
"""
Enum for embedding input type.
"""
DOCUMENT = auto()
QUERY = auto()
__all__ = ["EmbeddingInputType"]

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import json
import logging
import re
@@ -5,7 +7,7 @@ from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session
@@ -19,6 +21,7 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from dify_graph.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@@ -28,6 +31,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
)
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from dify_graph.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel):
- Load balancing configurations
- Model enablement/disablement
Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so
nested schema and model lookups reuse the caller scope that was already
resolved by the composition layer.
TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
"""
@@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
_bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _(self):
@@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
return self
def bind_model_runtime(self, model_runtime: ModelRuntime) -> None:
"""Attach the already-composed runtime for request-bound call chains."""
self._bound_model_runtime = model_runtime
def get_model_provider_factory(self) -> ModelProviderFactory:
"""Return a provider factory that preserves any request-bound runtime."""
if self._bound_model_runtime is not None:
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
"""
Get current credentials.
@@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
@@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = self.get_model_provider_factory()
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
@@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel):
"""
Get model schema
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = self.get_model_provider_factory()
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = self.get_model_provider_factory()
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types: list[ModelType] = []
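A sketch of the binding flow, assuming `provider_configuration` and `model_runtime` were already resolved by the composition layer:

provider_configuration.bind_model_runtime(model_runtime)
factory = provider_configuration.get_model_provider_factory()
# With a bound runtime, the factory reuses the caller's scope; without
# one, it falls back to create_plugin_model_provider_factory(tenant_id=...).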

View File

@@ -4,10 +4,10 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_hosting_provider import hosting_configuration
from models.provider import ProviderType
@@ -41,7 +41,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
text_chunk = secrets.choice(text_chunks)
try:
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(

View File

@@ -50,7 +50,10 @@ logger = logging.getLogger(__name__)
class IndexingRunner:
def __init__(self):
self.storage = storage
self.model_manager = ModelManager()
@staticmethod
def _get_model_manager(tenant_id: str) -> ModelManager:
return ModelManager.for_tenant(tenant_id=tenant_id)
def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
"""Handle indexing errors by updating document status."""
@@ -291,20 +294,20 @@ class IndexingRunner:
raise ValueError("Dataset not found.")
if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
@@ -574,7 +577,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
@@ -766,14 +769,14 @@ class IndexingRunner:
embedding_model_instance = None
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)

View File

@@ -62,7 +62,7 @@ class LLMGenerator:
prompt += query + "\n"
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -120,7 +120,7 @@ class LLMGenerator:
prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions})
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -172,7 +172,7 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt_generate)]
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
@@ -219,7 +219,7 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)]
# get model instance
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -306,7 +306,7 @@ class LLMGenerator:
remove_template_variables=False,
)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -337,7 +337,7 @@ class LLMGenerator:
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
prompt = GENERATOR_QA_PROMPT.format(language=document_language)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -362,7 +362,7 @@ class LLMGenerator:
@classmethod
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -536,7 +536,7 @@ class LLMGenerator:
injected_instruction = injected_instruction.replace(CURRENT, current or "null")
if ERROR_MESSAGE in injected_instruction:
injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
model_instance = ModelManager().get_model_instance(
model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,

View File

@@ -55,7 +55,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True],
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
@@ -70,7 +69,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False],
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
@@ -85,7 +83,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
@@ -99,7 +96,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
"""
@@ -113,7 +109,6 @@ def invoke_llm_with_structured_output(
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -143,7 +138,6 @@ def invoke_llm_with_structured_output(
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)

View File

@@ -4,6 +4,7 @@ from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.file_access import DatabaseFileAccessController
from core.model_manager import ModelInstance
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from dify_graph.file import file_manager
@@ -23,6 +24,8 @@ from models.workflow import Workflow
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
_file_access_controller = DatabaseFileAccessController()
class TokenBufferMemory:
def __init__(
@@ -85,7 +88,10 @@ class TokenBufferMemory:
# Build files directly without filtering by belongs_to
file_objs = [
file_factory.build_from_message_file(
message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config
message_file=message_file,
tenant_id=app_record.tenant_id,
config=file_extra_config,
access_controller=_file_access_controller,
)
for message_file in message_files
]

View File

@@ -7,11 +7,12 @@ from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import LLMResult
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
@@ -30,7 +31,7 @@ logger = logging.getLogger(__name__)
class ModelInstance:
"""
Model instance class
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
@@ -49,6 +50,13 @@ class ModelInstance:
credentials=self.credentials,
)
def get_model_schema(self) -> AIModelEntity:
"""Return the resolved schema for the current model instance."""
model_schema = self.model_type_instance.get_model_schema(self.model_name, self.credentials)
if model_schema is None:
raise ValueError(f"model schema not found for {self.model_name}")
return model_schema
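Callers can now centralize the None check instead of re-deriving the schema by hand; `model_instance` stands in for any constructed instance:

try:
    schema = model_instance.get_model_schema()
except ValueError:
    schema = None  # schema unresolved for this provider/credentials pair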
@staticmethod
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str):
"""
@@ -110,7 +118,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator: ...
@@ -122,7 +129,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResult: ...
@@ -134,7 +140,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]: ...
@@ -145,7 +150,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]:
"""
@@ -156,7 +160,6 @@ class ModelInstance:
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -173,7 +176,6 @@ class ModelInstance:
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
),
)
@@ -202,13 +204,12 @@ class ModelInstance:
)
def invoke_text_embedding(
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> EmbeddingResult:
"""
Invoke text embedding model
:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -221,7 +222,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
),
)
@@ -229,14 +229,12 @@ class ModelInstance:
def invoke_multimodal_embedding(
self,
multimodel_documents: list[dict],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke multimodal embedding model
:param multimodel_documents: multimodal documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -249,7 +247,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
input_type=input_type,
),
)
@@ -279,7 +276,6 @@ class ModelInstance:
docs: list[str],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -288,7 +284,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -303,7 +298,6 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
@@ -313,7 +307,6 @@ class ModelInstance:
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -322,7 +315,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -337,16 +329,14 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
def invoke_moderation(self, text: str) -> bool:
"""
Invoke moderation model
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
if not isinstance(self.model_type_instance, ModerationModel):
@@ -358,16 +348,14 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
text=text,
user=user,
),
)
def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str:
def invoke_speech2text(self, file: IO[bytes]) -> str:
"""
Invoke speech-to-text model
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
@@ -379,18 +367,15 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
file=file,
user=user,
),
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]:
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
"""
Invoke text-to-speech (TTS) model
:param content_text: text content to be synthesized into speech
:param tenant_id: user tenant id
:param voice: model timbre
:param user: unique user id
:return: audio byte stream for the given text
"""
if not isinstance(self.model_type_instance, TTSModel):
@@ -402,8 +387,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
),
)
@@ -477,10 +460,20 @@ class ModelInstance:
class ModelManager:
def __init__(self):
self._provider_manager = ProviderManager()
def __init__(self, provider_manager: ProviderManager):
self._provider_manager = provider_manager
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id))
def get_model_instance(
self,
tenant_id: str,
provider: str,
model_type: ModelType,
model: str,
) -> ModelInstance:
"""
Get model instance
:param tenant_id: tenant id
@@ -496,7 +489,8 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

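A minimal usage sketch of the tenant-bound construction path introduced here; the provider and model names are hypothetical placeholders:

manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = manager.get_model_instance(
    tenant_id=tenant_id,
    provider="openai",  # hypothetical
    model_type=ModelType.LLM,
    model="gpt-4o-mini",  # hypothetical
)
# Caller identity now travels with the manager, so the per-invoke `user=` argument is gone.
result = model_instance.invoke_llm(
    prompt_messages=[UserPromptMessage(content="Hello")],
    stream=False,
)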
View File

@@ -50,7 +50,7 @@ class OpenAIModeration(Moderation):
def _is_violated(self, inputs: dict):
text = "\n".join(str(value) for value in inputs.values())
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
)

View File

@@ -296,7 +296,9 @@ class AliyunDataTrace(BaseTraceInstance):
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
return workflow_node_execution_repository.get_by_workflow_execution(
workflow_execution_id=trace_info.workflow_run_id
)
def build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata

View File

@@ -271,8 +271,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
workflow_execution_id=trace_info.workflow_run_id
)
try:

View File

@@ -130,8 +130,8 @@ class LangFuseDataTrace(BaseTraceInstance):
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
workflow_execution_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:

View File

@@ -152,8 +152,8 @@ class LangSmithDataTrace(BaseTraceInstance):
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
workflow_execution_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:

View File

@@ -176,8 +176,8 @@ class OpikDataTrace(BaseTraceInstance):
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
workflow_execution_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:

View File

@@ -256,7 +256,7 @@ class TencentDataTrace(BaseTraceInstance):
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
executions = repository.get_by_workflow_execution(workflow_execution_id=trace_info.workflow_run_id)
return list(executions)
except Exception:

View File

@@ -161,8 +161,8 @@ class WeaveDataTrace(BaseTraceInstance):
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
workflow_execution_id=trace_info.workflow_run_id
)
# rearrange workflow_node_executions by starting time

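All seven tracing integrations above migrate the same call; the renamed repository surface, sketched as a protocol (only the method name and keyword are taken from this diff, the protocol itself is illustrative):

from collections.abc import Sequence
from typing import Protocol

class NodeExecutionRepository(Protocol):  # illustrative name
    def get_by_workflow_execution(self, workflow_execution_id: str) -> Sequence["WorkflowNodeExecution"]: ...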
View File

@@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Tenant
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
@staticmethod
def _get_bound_model_instance(
*,
tenant_id: str,
user_id: str | None,
provider: str,
model_type: ModelType,
model: str,
):
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=provider,
model_type=model_type,
model=model,
)
@classmethod
def invoke_llm(
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
@@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
)
if isinstance(response, Generator):
@@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm with structured output
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
model_parameters=payload.completion_params,
)
@@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke text embedding
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_text_embedding(
texts=payload.texts,
user=user_id,
)
response = model_instance.invoke_text_embedding(texts=payload.texts)
return response
@@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke rerank
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
docs=payload.docs,
score_threshold=payload.score_threshold,
top_n=payload.top_n,
user=user_id,
)
return response
@@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke tts
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_tts(
content_text=payload.content_text,
tenant_id=tenant.id,
voice=payload.voice,
user=user_id,
)
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
def handle() -> Generator[dict, None, None]:
for chunk in response:
@@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke speech2text
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
temp.flush()
temp.seek(0)
response = model_instance.invoke_speech2text(
file=temp,
user=user_id,
)
response = model_instance.invoke_speech2text(file=temp)
return {
"result": response,
@@ -252,36 +261,38 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke moderation
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_moderation(
text=payload.text,
user=user_id,
)
response = model_instance.invoke_moderation(text=payload.text)
return {
"result": response,
}
@classmethod
def get_system_model_max_tokens(cls, tenant_id: str) -> int:
def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int:
"""
get system model max tokens
"""
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id)
@classmethod
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int:
"""
get prompt tokens
"""
return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
return ModelInvocationUtils.calculate_tokens(
tenant_id=tenant_id,
prompt_messages=prompt_messages,
user_id=user_id,
)
@classmethod
def invoke_system_model(
@@ -299,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tool_type=ToolProviderType.PLUGIN,
tool_name="plugin",
prompt_messages=prompt_messages,
caller_user_id=user_id,
)
@classmethod
@@ -306,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke summary
"""
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id)
content = payload.text
SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
@@ -325,6 +337,7 @@ Here is the extra instruction you need to follow:
cls.get_prompt_tokens(
tenant_id=tenant.id,
prompt_messages=[UserPromptMessage(content=content)],
user_id=user_id,
)
< max_tokens * 0.6
):
@@ -337,6 +350,7 @@ Here is the extra instruction you need to follow:
SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
UserPromptMessage(content=content),
],
user_id=user_id,
)
def summarize(content: str) -> str:
@@ -394,6 +408,7 @@ Here is the extra instruction you need to follow:
cls.get_prompt_tokens(
tenant_id=tenant.id,
prompt_messages=[UserPromptMessage(content=result)],
user_id=user_id,
)
> max_tokens * 0.7
):

View File

@@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
tool_type,
tenant_id,
provider,
tool_name,
tool_parameters,
user_id=user_id,
credential_id=credential_id,
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1

View File

@@ -1,6 +1,6 @@
import binascii
from collections.abc import Generator, Sequence
from typing import IO
from typing import IO, Any
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
@@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
class PluginModelClient(BasePluginClient):
@staticmethod
def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {"data": data}
if user_id is not None:
payload["user_id"] = user_id
return payload
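A quick illustration of both branches of the helper (values hypothetical); a None caller is expressed by omitting the key entirely rather than sending an explicit null:

PluginModelClient._dispatch_payload(user_id="u-1", data={"provider": "openai"})
# -> {"data": {"provider": "openai"}, "user_id": "u-1"}
PluginModelClient._dispatch_payload(user_id=None, data={"provider": "openai"})
# -> {"data": {"provider": "openai"}}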
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
"""
Fetch model providers for the given tenant.
@@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient):
def get_model_schema(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/schema",
PluginModelSchemaEntity,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient):
return None
def validate_provider_credentials(
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
) -> bool:
"""
validate the credentials of the provider
@@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient):
def validate_model_credentials(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient):
def invoke_llm(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type_=LLMResultChunk,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "llm",
"model": model,
@@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient):
"stop": stop,
"stream": stream,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient):
def get_llm_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type_=PluginLLMNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
@@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient):
"prompt_messages": prompt_messages,
"tools": tools,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient):
def invoke_text_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient):
"texts": texts,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient):
"documents": documents,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient):
def get_text_embedding_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
"credentials": credentials,
"texts": texts,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient):
def invoke_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
@@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient):
def invoke_tts(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
@@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient):
"content_text": content_text,
"voice": voice,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient):
def get_tts_model_voices(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type_=PluginVoicesResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
"credentials": credentials,
"language": language,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient):
def invoke_speech_to_text(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "speech2text",
"model": model,
"credentials": credentials,
"file": binascii.hexlify(file.read()).decode(),
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient):
def invoke_moderation(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type_=PluginBasicBooleanResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "moderation",
"model": model,
"credentials": credentials,
"text": text,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,

View File

@@ -0,0 +1,499 @@
from __future__ import annotations
import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union
from pydantic import ValidationError
from redis import RedisError
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.runtime import ModelRuntime
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
# `TS` means tenant scope
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
tenant_id: str
user_id: str | None
client: PluginModelClient
_provider_entities: tuple[ProviderEntity, ...] | None
_provider_entities_lock: Lock
def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
if client is None:
raise ValueError("client is required.")
self.tenant_id = tenant_id
self.user_id = user_id
self.client = client
self._provider_entities = None
self._provider_entities_lock = Lock()
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
if self._provider_entities is not None:
return self._provider_entities
with self._provider_entities_lock:
if self._provider_entities is None:
self._provider_entities = tuple(
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
)
return self._provider_entities
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
provider_schema = self._get_provider_schema(provider)
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_Hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_US
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")
image_mime_types = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"bmp": "image/bmp",
"tiff": "image/tiff",
"tif": "image/tiff",
"webp": "image/webp",
"svg": "image/svg+xml",
"ico": "image/vnd.microsoft.icon",
"heif": "image/heif",
"heic": "image/heic",
}
extension = file_name.split(".")[-1]
mime_type = image_mime_types.get(extension, "image/png")
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
credentials=credentials,
)
def validate_model_credentials(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_model_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
def get_model_schema(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> AIModelEntity | None:
cache_key = self._get_schema_cache_key(
provider=provider,
model_type=model_type,
model=model,
credentials=credentials,
)
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
plugin_id, provider_name = self._split_provider(provider)
schema = self.client.get_model_schema(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_llm(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=list(prompt_messages),
tools=tools,
stop=list(stop) if stop else None,
stream=stream,
)
def get_llm_num_tokens(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
) -> int:
if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
return 0
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)
def invoke_text_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
def invoke_multimodal_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
documents: list[dict[str, Any]],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
documents=documents,
input_type=input_type,
)
def get_text_embedding_num_tokens(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
)
def invoke_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_multimodal_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_tts(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Iterable[bytes]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_tts(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def get_tts_model_voices(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
language: str | None,
) -> Any:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
language=language,
)
def invoke_speech_to_text(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
file=file,
)
def invoke_moderation(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
text: str,
) -> bool:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_moderation(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
text=text,
)
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
"""
Expose a bare provider alias only for the canonical provider mapping.
Multiple plugins can publish the same short provider slug. If every
provider entity keeps that slug in ``provider_name``, callers that still
resolve by short name become order-dependent. Restrict the alias to the
provider selected by ``ModelProviderID`` so legacy short-name lookups
remain deterministic while the runtime surface stays canonical.
"""
try:
canonical_provider_id = ModelProviderID(provider.provider)
except ValueError:
return ""
if canonical_provider_id.plugin_id != provider.plugin_id:
return ""
if canonical_provider_id.provider_name != provider.provider:
return ""
return provider.provider
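How the alias resolves in practice, assuming the short slug parses to a canonical langgenius-owned provider id; both plugin ids below are hypothetical:

# canonical plugin publishing its own slug -> alias kept:
#   plugin_id="langgenius/openai", provider="openai"      => "openai"
# another plugin reusing the same short slug -> alias suppressed:
#   plugin_id="acme/openai-clone", provider="openai"      => ""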
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
declaration = provider.declaration.model_copy(deep=True)
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
declaration.provider_name = self._get_provider_short_name_alias(provider)
return declaration
def _get_provider_schema(self, provider: str) -> ProviderEntity:
providers = self.fetch_model_providers()
provider_entity = next((item for item in providers if item.provider == provider), None)
if provider_entity is None:
provider_entity = next((item for item in providers if provider == item.provider_name), None)
if provider_entity is None:
raise ValueError(f"Invalid provider: {provider}")
return provider_entity
def _get_schema_cache_key(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> str:
# The plugin daemon distinguishes ``None`` from an explicit empty-string
# caller id, so the cache must only collapse ``None`` into tenant scope.
cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}"
sorted_credentials = sorted(credentials.items()) if credentials else []
if not sorted_credentials:
return cache_key
hashed_credentials = ":".join(
[hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
)
return f"{cache_key}:{hashed_credentials}"
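For example, a tenant-scoped lookup (user_id=None) with one credential pair would produce a key of this shape; all values are hypothetical and the md5 digest is truncated for readability:

# tenant "t-1", provider "langgenius/openai/openai", ModelType.LLM, model "gpt-4o",
# credentials={"api_key": "sk-..."}
# -> "t-1:langgenius/openai/openai:llm:gpt-4o:__DIFY_TS__:0d7f3dd6..."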
def _split_provider(self, provider: str) -> tuple[str, str]:
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
if TYPE_CHECKING:
from core.model_manager import ModelManager
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
class PluginModelAssembly:
"""Compose request-scoped model views on top of a single plugin runtime."""
tenant_id: str
user_id: str | None
_model_runtime: PluginModelRuntime | None
_model_provider_factory: ModelProviderFactory | None
_provider_manager: ProviderManager | None
_model_manager: ModelManager | None
def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None:
self.tenant_id = tenant_id
self.user_id = user_id
self._model_runtime = None
self._model_provider_factory = None
self._provider_manager = None
self._model_manager = None
@property
def model_runtime(self) -> PluginModelRuntime:
if self._model_runtime is None:
self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id)
return self._model_runtime
@property
def model_provider_factory(self) -> ModelProviderFactory:
if self._model_provider_factory is None:
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
return self._model_provider_factory
@property
def provider_manager(self) -> ProviderManager:
if self._provider_manager is None:
from core.provider_manager import ProviderManager
self._provider_manager = ProviderManager(model_runtime=self.model_runtime)
return self._provider_manager
@property
def model_manager(self) -> ModelManager:
if self._model_manager is None:
from core.model_manager import ModelManager
self._model_manager = ModelManager(provider_manager=self.provider_manager)
return self._model_manager
def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly:
"""Create a request-scoped assembly that shares one plugin runtime across model views."""
return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id)
def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime:
"""Create a plugin runtime with its client dependency fully composed."""
from core.plugin.impl.model_runtime import PluginModelRuntime
return PluginModelRuntime(
tenant_id=tenant_id,
user_id=user_id,
client=PluginModelClient(),
)
def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory:
"""Create a tenant-bound model provider factory for service flows."""
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory
def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
"""Create a tenant-bound provider manager for service flows."""
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager
def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
"""Create a tenant-bound model manager for service flows."""
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager
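A usage sketch: every view is memoized, and all of them sit on the one runtime constructed on first access (tenant id hypothetical):

assembly = create_plugin_model_assembly(tenant_id="t-1")
runtime = assembly.model_runtime  # built once, then cached on the assembly
manager = assembly.model_manager  # runtime -> ProviderManager -> ModelManager
factory = assembly.model_provider_factory  # reuses the same PluginModelRuntime instance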

View File

@@ -1,50 +1,7 @@
from typing import Literal
from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from pydantic import BaseModel
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
class ChatModelMessage(BaseModel):
"""
Chat Message.
"""
text: str
role: PromptMessageRole
edition_type: Literal["basic", "jinja2"] | None = None
class CompletionModelPromptTemplate(BaseModel):
"""
Completion Model Prompt Template.
"""
text: str
edition_type: Literal["basic", "jinja2"] | None = None
class MemoryConfig(BaseModel):
"""
Memory Config.
"""
class RolePrefix(BaseModel):
"""
Role Prefix.
"""
user: str
assistant: str
class WindowConfig(BaseModel):
"""
Window Config.
"""
enabled: bool
size: int | None = None
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
__all__ = [
"ChatModelMessage",
"CompletionModelPromptTemplate",
"MemoryConfig",
]
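The module is now a thin re-export, so existing imports keep resolving to the same classes as the canonical path; the old module path is not shown by this diff view, hence the placeholder:

from dify_graph.prompt_entities import ChatModelMessage, MemoryConfig  # canonical location
# from <old module path> import ChatModelMessage  # still works via the re-export above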

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -53,15 +55,25 @@ from models.provider import (
from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService
if TYPE_CHECKING:
from dify_graph.model_runtime.runtime import ModelRuntime
class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
ProviderManager manages tenant-scoped model provider configuration.
The runtime adapter is injected by the composition layer so this class stays
focused on configuration assembly instead of constructing plugin runtimes.
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
"""
def __init__(self):
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@@ -127,7 +139,7 @@ class ProviderManager:
)
# Get all provider entities
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()
# Get All preferred provider types of the workspace
@@ -255,6 +267,7 @@ class ProviderManager:
custom_configuration=custom_configuration,
model_settings=model_settings,
)
provider_configuration.bind_model_runtime(self._model_runtime)
provider_configurations[str(provider_id_entity)] = provider_configuration
@@ -321,7 +334,7 @@ class ProviderManager:
if not default_model:
return None
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
return DefaultModelEntity(
@@ -392,7 +405,7 @@ class ProviderManager:
# create default model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
model_type=model_type.to_origin_model_type(),
provider_name=provider,
model_name=model,
)

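With the constructor change, ProviderManager is assembled at the composition layer; both construction paths appear in this diff:

provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)  # service flows
provider_manager = ProviderManager(model_runtime=model_runtime)  # direct injection, e.g. from PluginModelAssembly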
View File

@@ -52,11 +52,10 @@ class DataPostProcessor:
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type)
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, query_type)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)
@@ -106,9 +105,9 @@ class DataPostProcessor:
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(

View File

@@ -328,7 +328,7 @@ class RetrievalService:
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
if dataset.is_multimodal:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model["reranking_provider_name"],

View File

@@ -303,7 +303,7 @@ class Vector:
redis_client.delete(collection_exist_cache_key)
def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,

View File

@@ -73,7 +73,7 @@ class DatasetDocumentStore:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,

View File

@@ -21,9 +21,8 @@ logger = logging.getLogger(__name__)
class CacheEmbedding(Embeddings):
def __init__(self, model_instance: ModelInstance, user: str | None = None):
def __init__(self, model_instance: ModelInstance):
self._model_instance = model_instance
self._user = user
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
@@ -65,7 +64,7 @@ class CacheEmbedding(Embeddings):
batch_texts = embedding_queue_texts[i : i + max_chunks]
embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT
)
for vector in embedding_result.embeddings:
@@ -147,7 +146,6 @@ class CacheEmbedding(Embeddings):
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)
@@ -202,7 +200,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
texts=[text], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
@@ -245,7 +243,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]

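With the `user` parameter gone, the embedding cache is built from the model instance alone; a minimal sketch (``dataset.embedding_model`` is an assumed attribute name):

embedding_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance(
    tenant_id=tenant_id,
    provider=dataset.embedding_model_provider,
    model_type=ModelType.TEXT_EMBEDDING,
    model=dataset.embedding_model,  # assumed attribute name
)
embeddings = CacheEmbedding(embedding_instance)  # no per-user handle anymore
vectors = embeddings.embed_documents(["hello", "world"])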
View File

@@ -8,11 +8,12 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
from core.app.file_access import DatabaseFileAccessController
from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
@@ -27,6 +28,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor, Su
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file_reference import build_file_reference
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
@@ -48,6 +50,8 @@ from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController()
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
@@ -410,7 +414,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# If default prompt doesn't have {language} placeholder, use it as-is
pass
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)
@@ -555,6 +559,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
file_obj = build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
access_controller=_file_access_controller,
)
file_objects.append(file_obj)
except Exception as e:
@@ -604,11 +609,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
reference=build_file_reference(
record_id=str(upload_file.id),
),
size=upload_file.size,
storage_key=upload_file.key,
)

View File

@@ -12,7 +12,6 @@ class BaseRerankRunner(ABC):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
@@ -21,7 +20,6 @@ class BaseRerankRunner(ABC):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
raise NotImplementedError

View File

@@ -22,7 +22,6 @@ class RerankModelRunner(BaseRerankRunner):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
@@ -31,10 +30,11 @@ class RerankModelRunner(BaseRerankRunner):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
@@ -43,12 +43,12 @@ class RerankModelRunner(BaseRerankRunner):
)
if not is_support_vision:
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n)
else:
return documents
else:
rerank_result, unique_documents = self.fetch_multimodal_rerank(
query, documents, score_threshold, top_n, user, query_type
query, documents, score_threshold, top_n, query_type
)
rerank_documents = []
@@ -73,7 +73,6 @@ class RerankModelRunner(BaseRerankRunner):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch text rerank
@@ -81,7 +80,6 @@ class RerankModelRunner(BaseRerankRunner):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
docs = []
@@ -103,7 +101,7 @@ class RerankModelRunner(BaseRerankRunner):
unique_documents.append(document)
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
@@ -113,7 +111,6 @@ class RerankModelRunner(BaseRerankRunner):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> tuple[RerankResult, list[Document]]:
"""
@@ -122,7 +119,6 @@ class RerankModelRunner(BaseRerankRunner):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:param query_type: query type
:return: rerank result
"""
@@ -168,7 +164,7 @@ class RerankModelRunner(BaseRerankRunner):
documents = unique_documents
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n)
return rerank_result, unique_documents
elif query_type == QueryType.IMAGE_QUERY:
# Query file info within db.session context to ensure thread-safe access
@@ -181,7 +177,7 @@ class RerankModelRunner(BaseRerankRunner):
"content_type": DocType.IMAGE,
}
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
else:

View File

@@ -25,7 +25,6 @@ class WeightRerankRunner(BaseRerankRunner):
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
@@ -34,7 +33,6 @@ class WeightRerankRunner(BaseRerankRunner):
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
@@ -163,7 +161,7 @@ class WeightRerankRunner(BaseRerankRunner):
"""
query_vector_scores = []
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=tenant_id,

View File

@@ -56,6 +56,7 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.file_reference import build_file_reference
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.nodes.knowledge_retrieval.retrieval import (
KnowledgeRetrievalRequest,
@@ -160,7 +161,7 @@ class DatasetRetrieval:
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id)
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
@@ -383,23 +384,27 @@ class DatasetRetrieval:
return None, []
retrieve_config = config.retrieve_config
# check whether the model supports tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
)
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
# get model schema
# Reuse the caller-bound model instance for both schema resolution and
# downstream planner/invoke calls so a single request never mixes
# tenant-scope and request-bound runtimes.
model_schema = model_type_instance.get_model_schema(
model=model_config.model, credentials=model_config.credentials
model=model_instance.model_name,
credentials=model_instance.credentials,
)
if not model_schema:
return None, []
model_config.provider_model_bundle = model_instance.provider_model_bundle
model_config.credentials = model_instance.credentials
model_config.model_schema = model_schema
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
@@ -517,11 +522,12 @@ class DatasetRetrieval:
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=segment.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
reference=build_file_reference(
record_id=str(upload_file.id),
),
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
@@ -986,6 +992,24 @@ class DatasetRetrieval:
)
)
@staticmethod
def _resolve_creator_user_role(user_from: str) -> CreatorUserRole | None:
"""Map runtime user source values to dataset query audit roles.
Workflow run context uses the hyphenated ``end-user`` value, while
``DatasetQuery.created_by_role`` persists the underscore-based
``CreatorUserRole.END_USER`` enum. Query logging is a side effect, so an
unsupported value should be skipped instead of aborting retrieval.
"""
normalized_user_from = str(user_from).strip().lower().replace("-", "_")
if normalized_user_from == CreatorUserRole.ACCOUNT.value:
return CreatorUserRole.ACCOUNT
if normalized_user_from == CreatorUserRole.END_USER.value:
return CreatorUserRole.END_USER
logger.warning("Skipping dataset query audit log for unsupported user_from=%r", user_from)
return None
def _on_query(
self,
query: str | None,
@@ -996,10 +1020,13 @@ class DatasetRetrieval:
user_id: str,
):
"""
Handle query.
Persist dataset query audit rows for retrieval requests.
"""
if not query and not attachment_ids:
return
created_by_role = self._resolve_creator_user_role(user_from)
if created_by_role is None:
return
dataset_queries = []
for dataset_id in dataset_ids:
contents = []
@@ -1014,7 +1041,7 @@ class DatasetRetrieval:
content=json.dumps(contents),
source=DatasetQuerySource.APP,
source_app_id=app_id,
created_by_role=CreatorUserRole(user_from),
created_by_role=created_by_role,
created_by=user_id,
)
dataset_queries.append(dataset_query)
@@ -1411,7 +1438,7 @@ class DatasetRetrieval:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id)
# fetch prompt messages
prompt_messages, stop = self._get_prompt_template(
@@ -1430,7 +1457,6 @@ class DatasetRetrieval:
model_parameters=model_config.parameters,
stop=stop,
stream=True,
user=user_id,
),
)
@@ -1533,7 +1559,7 @@ class DatasetRetrieval:
return filters
def _fetch_model_config(
self, tenant_id: str, model: ModelConfig
self, tenant_id: str, model: ModelConfig, user_id: str | None = None
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
@@ -1543,7 +1569,7 @@ class DatasetRetrieval:
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)
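The _resolve_creator_user_role normalization above is small but load-bearing, so a short doctest-style illustration may help (the enum value follows the docstring in this hunk; the "system" example is hypothetical):

normalized = "End-User".strip().lower().replace("-", "_")
assert normalized == "end_user"  # CreatorUserRole.END_USER.value per the docstring above
# An unrecognized source such as "system" yields None, so the audit row is
# skipped with a warning instead of aborting retrieval.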

View File

@@ -3,13 +3,14 @@ from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelType
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -119,6 +120,7 @@ class ReactMultiDatasetRouter:
memory_config=None,
memory=None,
model_config=model_config,
model_instance=model_instance,
)
result_text, usage = self._invoke_llm(
completion_param=model_config.parameters,
@@ -150,19 +152,24 @@ class ReactMultiDatasetRouter:
:param stop: stop
:return:
"""
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=model_instance.provider,
model_type=ModelType.LLM,
model=model_instance.model_name,
)
invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
)
# handle invoke result
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage)
return text, usage

View File

@@ -4,7 +4,13 @@ from __future__ import annotations
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
from .factory import (
DifyCoreRepositoryFactory,
OrderConfig,
RepositoryImportError,
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository,
)
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
@@ -12,7 +18,10 @@ __all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"OrderConfig",
"RepositoryImportError",
"SQLAlchemyWorkflowExecutionRepository",
"SQLAlchemyWorkflowNodeExecutionRepository",
"WorkflowExecutionRepository",
"WorkflowNodeExecutionRepository",
]
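With these re-exports in place, call sites can presumably import the ordering config and both repository protocols from one package path instead of reaching into dify_graph:

from core.repositories import (
    OrderConfig,
    WorkflowExecutionRepository,
    WorkflowNodeExecutionRepository,
)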

View File

@@ -11,8 +11,8 @@ from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import WorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom

View File

@@ -12,11 +12,11 @@ from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution
from dify_graph.repositories.workflow_node_execution_repository import (
from core.repositories.factory import (
OrderConfig,
WorkflowNodeExecutionRepository,
)
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@@ -148,24 +148,24 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# For now, we'll re-raise the exception
raise
def get_by_workflow_run(
def get_by_workflow_execution(
self,
workflow_run_id: str,
workflow_execution_id: str,
order_config: OrderConfig | None = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache.
Retrieve all workflow node executions for a workflow execution from cache.
Args:
workflow_run_id: The workflow run ID
workflow_execution_id: The workflow execution identifier
order_config: Optional configuration for ordering results
Returns:
A sequence of WorkflowNodeExecution instances
"""
try:
# Get execution IDs for this workflow run from cache
execution_ids = self._workflow_execution_mapping.get(workflow_run_id, [])
# Get execution IDs for this workflow execution from cache
execution_ids = self._workflow_execution_mapping.get(workflow_execution_id, [])
# Retrieve executions from cache
result = []
@@ -182,9 +182,16 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
for field_name in reversed(order_config.order_by):
result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)
logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id)
logger.debug(
"Retrieved %d workflow node executions for execution %s from cache",
len(result),
workflow_execution_id,
)
return result
except Exception:
logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id)
logger.exception(
"Failed to get workflow node executions for execution %s from cache",
workflow_execution_id,
)
return []
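A hedged usage sketch of the renamed lookup, where repo stands for any WorkflowNodeExecutionRepository implementation and the field name and id value are invented:

from core.repositories.factory import OrderConfig

order = OrderConfig(order_by=["created_at"], order_direction="desc")
executions = repo.get_by_workflow_execution(
    workflow_execution_id="run-123",
    order_config=order,
)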

View File

@@ -5,20 +5,45 @@ This module provides a Django-like settings system for repository implementation
allowing users to configure different repository backends through string paths.
"""
from typing import Union
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Protocol, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution
from libs.module_loading import import_string
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@dataclass
class OrderConfig:
"""Configuration for ordering node execution instances."""
order_by: list[str]
order_direction: Literal["asc", "desc"] | None = None
class WorkflowExecutionRepository(Protocol):
def save(self, execution: WorkflowExecution): ...
class WorkflowNodeExecutionRepository(Protocol):
def save(self, execution: WorkflowNodeExecution): ...
def save_execution_data(self, execution: WorkflowNodeExecution): ...
def get_by_workflow_execution(
self,
workflow_execution_id: str,
order_config: OrderConfig | None = None,
) -> Sequence[WorkflowNodeExecution]: ...
class RepositoryImportError(Exception):
"""Raised when a repository implementation cannot be imported or instantiated."""

View File

@@ -2,33 +2,23 @@ import dataclasses
import json
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from typing import Any, Protocol
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
from dify_graph.nodes.human_input.entities import (
from core.workflow.human_input_compat import (
BoundRecipient,
DeliveryChannelConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormDefinition,
HumanInputNodeData,
MemberRecipient,
WebAppDeliveryMethod,
)
from dify_graph.nodes.human_input.enums import (
DeliveryMethodType,
HumanInputFormKind,
HumanInputFormStatus,
)
from dify_graph.repositories.human_input_form_repository import (
FormCreateParams,
FormNotFoundError,
HumanInputFormEntity,
HumanInputFormRecipientEntity,
InteractiveSurfaceDeliveryMethod,
is_human_input_webapp_enabled,
)
from dify_graph.nodes.human_input.entities import FormDefinition, HumanInputNodeData
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.account import Account, TenantAccountJoin
@@ -36,6 +26,7 @@ from models.human_input import (
BackstageRecipientPayload,
ConsoleDeliveryPayload,
ConsoleRecipientPayload,
DeliveryMethodType,
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputDelivery,
@@ -58,6 +49,65 @@ class _WorkspaceMemberInfo:
email: str
class FormNotFoundError(Exception):
pass
@dataclasses.dataclass
class FormCreateParams:
workflow_execution_id: str | None
node_id: str
form_config: HumanInputNodeData
rendered_content: str
delivery_methods: Sequence[DeliveryChannelConfig]
display_in_ui: bool
resolved_default_values: Mapping[str, Any]
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
class HumanInputFormRecipientEntity(Protocol):
@property
def id(self) -> str: ...
@property
def token(self) -> str: ...
class HumanInputFormEntity(Protocol):
@property
def id(self) -> str: ...
@property
def submission_token(self) -> str | None: ...
@property
def recipients(self) -> list[HumanInputFormRecipientEntity]: ...
@property
def rendered_content(self) -> str: ...
@property
def selected_action_id(self) -> str | None: ...
@property
def submitted_data(self) -> Mapping[str, Any] | None: ...
@property
def submitted(self) -> bool: ...
@property
def status(self) -> HumanInputFormStatus: ...
@property
def expiration_time(self) -> datetime: ...
class HumanInputFormRepository(Protocol):
def get_form(self, node_id: str) -> HumanInputFormEntity | None: ...
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: ...
class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
def __init__(self, recipient_model: HumanInputFormRecipient):
self._recipient_model = recipient_model
@@ -77,7 +127,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
self._form_model = form_model
self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
self._web_app_recipient = next(
self._interactive_surface_recipient = next(
(
recipient
for recipient in recipient_models
@@ -98,12 +148,12 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
return self._form_model.id
@property
def web_app_token(self):
def submission_token(self) -> str | None:
if self._console_recipient is not None:
return self._console_recipient.access_token
if self._web_app_recipient is None:
if self._interactive_surface_recipient is None:
return None
return self._web_app_recipient.access_token
return self._interactive_surface_recipient.access_token
@property
def recipients(self) -> list[HumanInputFormRecipientEntity]:
@@ -201,8 +251,16 @@ class HumanInputFormRepositoryImpl:
self,
*,
tenant_id: str,
):
app_id: str | None = None,
workflow_execution_id: str | None = None,
invoke_source: str | None = None,
submission_actor_id: str | None = None,
) -> None:
self._tenant_id = tenant_id
self._app_id = app_id
self._workflow_execution_id = workflow_execution_id
self._invoke_source = invoke_source
self._submission_actor_id = submission_actor_id
def _delivery_method_to_model(
self,
@@ -219,7 +277,7 @@ class HumanInputFormRepositoryImpl:
channel_payload=delivery_method.model_dump_json(),
)
recipients: list[HumanInputFormRecipient] = []
if isinstance(delivery_method, WebAppDeliveryMethod):
if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod):
recipient_model = HumanInputFormRecipient(
form_id=form_id,
delivery_id=delivery_id,
@@ -247,16 +305,16 @@ class HumanInputFormRepositoryImpl:
delivery_id: str,
recipients_config: EmailRecipients,
) -> list[HumanInputFormRecipient]:
member_user_ids = [
recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient)
bound_reference_ids = [
recipient.reference_id for recipient in recipients_config.items if isinstance(recipient, BoundRecipient)
]
external_emails = [
recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient)
]
if recipients_config.whole_workspace:
if recipients_config.include_bound_group:
members = self._query_all_workspace_members(session=session)
else:
members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids)
members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=bound_reference_ids)
return self._create_email_recipients_from_resolved(
form_id=form_id,
@@ -338,8 +396,33 @@ class HumanInputFormRepositoryImpl:
rows = session.execute(stmt).all()
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
def _should_create_console_recipient(
self,
*,
form_config: HumanInputNodeData,
form_kind: HumanInputFormKind,
) -> bool:
if form_kind != HumanInputFormKind.RUNTIME:
return False
if self._invoke_source == "debugger":
return True
if self._invoke_source == "explore":
return is_human_input_webapp_enabled(form_config)
return False
def _should_create_backstage_recipient(self, *, form_kind: HumanInputFormKind) -> bool:
return form_kind == HumanInputFormKind.RUNTIME and (
self._invoke_source is not None or self._submission_actor_id is not None
)
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
form_config: HumanInputNodeData = params.form_config
app_id = self._app_id
if not app_id:
raise ValueError("app_id is required to create a human input form")
workflow_execution_id = params.workflow_execution_id or self._workflow_execution_id
if params.form_kind == HumanInputFormKind.RUNTIME and workflow_execution_id is None:
raise ValueError("workflow_execution_id is required for runtime human input forms")
with session_factory.create_session() as session, session.begin():
# Generate unique form ID
@@ -359,8 +442,8 @@ class HumanInputFormRepositoryImpl:
form_model = HumanInputForm(
id=form_id,
tenant_id=self._tenant_id,
app_id=params.app_id,
workflow_run_id=params.workflow_execution_id,
app_id=app_id,
workflow_run_id=workflow_execution_id,
form_kind=params.form_kind,
node_id=params.node_id,
form_definition=form_definition.model_dump_json(),
@@ -379,7 +462,7 @@ class HumanInputFormRepositoryImpl:
session.add(delivery_and_recipients.delivery)
session.add_all(delivery_and_recipients.recipients)
recipient_models.extend(delivery_and_recipients.recipients)
if params.console_recipient_required and not any(
if self._should_create_console_recipient(form_config=form_config, form_kind=params.form_kind) and not any(
recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models
):
console_delivery_id = str(uuidv7())
@@ -395,13 +478,13 @@ class HumanInputFormRepositoryImpl:
delivery_id=console_delivery_id,
recipient_type=RecipientType.CONSOLE,
recipient_payload=ConsoleRecipientPayload(
account_id=params.console_creator_account_id,
account_id=self._submission_actor_id,
).model_dump_json(),
)
session.add(console_delivery)
session.add(console_recipient)
recipient_models.append(console_recipient)
if params.backstage_recipient_required and not any(
if self._should_create_backstage_recipient(form_kind=params.form_kind) and not any(
recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models
):
backstage_delivery_id = str(uuidv7())
@@ -417,7 +500,7 @@ class HumanInputFormRepositoryImpl:
delivery_id=backstage_delivery_id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload(
account_id=params.console_creator_account_id,
account_id=self._submission_actor_id,
).model_dump_json(),
)
session.add(backstage_delivery)
@@ -427,9 +510,12 @@ class HumanInputFormRepositoryImpl:
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
def get_form(self, node_id: str) -> HumanInputFormEntity | None:
if self._workflow_execution_id is None:
raise ValueError("workflow_execution_id is required to load runtime human input forms")
form_query = select(HumanInputForm).where(
HumanInputForm.workflow_run_id == workflow_execution_id,
HumanInputForm.workflow_run_id == self._workflow_execution_id,
HumanInputForm.node_id == node_id,
HumanInputForm.tenant_id == self._tenant_id,
)
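A hedged construction sketch of the context-bound repository; the constructor kwargs come from the hunks above, while every literal value is invented:

# Request context now lives on the repository rather than on FormCreateParams.
repo = HumanInputFormRepositoryImpl(
    tenant_id="tenant-1",
    app_id="app-1",                   # required before create_form()
    workflow_execution_id="run-1",    # required for RUNTIME forms
    invoke_source="debugger",         # drives console-recipient creation
    submission_actor_id="account-1",  # recorded on console/backstage payloads
)
form = repo.get_form(node_id="human_input_1")  # scoped by the bound execution id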

View File

@@ -9,9 +9,9 @@ from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import WorkflowExecutionRepository
from dify_graph.entities import WorkflowExecution
from dify_graph.enums import WorkflowExecutionStatus, WorkflowType
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.helper import extract_tenant_id
from models import (

Some files were not shown because too many files have changed in this diff.