Compare commits


1 Commit

Author: -LAN-
SHA1: 0d3aab5901
Message: refactor(api): move TokenBufferMemory to model_runtime
Date: 2026-02-28 18:02:39 +08:00
318 changed files with 2788 additions and 3177 deletions
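
Taken together, the diffs below amount to two import-path swaps applied across the API codebase; a before/after sketch using the repo's own module paths (illustrative only, not runnable outside the repo):

```python
# TokenBufferMemory moves into model_runtime:
# before: from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.token_buffer_memory import TokenBufferMemory

# The prompt package moves out of model_runtime:
# before: from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.prompt.simple_prompt_transform import ModelMode
```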

View File

@@ -68,9 +68,10 @@ lint:
@echo "✅ Linting complete"
type-check:
@echo "📝 Running type checks (basedpyright + mypy)..."
@echo "📝 Running type checks (basedpyright + mypy + ty)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@cd api && uv run ty check
@echo "✅ Type checks complete"
test:
@@ -131,7 +132,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checks (basedpyright, mypy)"
@echo " make type-check - Run type checks (basedpyright, mypy, ty)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"

View File

@@ -91,7 +91,7 @@ forbidden_modules =
core.moderation
core.ops
core.plugin
core.model_runtime.prompt
core.prompt
core.provider_manager
core.rag
core.repositories
@@ -106,9 +106,11 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.provider_manager
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
@@ -127,10 +129,14 @@ ignore_imports =
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.advanced_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.simple_prompt_transform
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.simple_prompt_transform
core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.start.entities -> core.app.app_config.entities
core.workflow.nodes.start.start_node -> core.app.app_config.entities
core.workflow.workflow_entry -> core.app.apps.exc
@@ -139,25 +145,33 @@ ignore_imports =
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.variables.segments
core.workflow.nodes.loop.entities -> core.variables.types
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
core.workflow.nodes.tool.tool_node -> models
core.workflow.nodes.agent.agent_node -> models.model
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
core.workflow.nodes.llm.node -> core.helper.code_executor
core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.model_runtime.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.utils.prompt_message_util
core.workflow.nodes.question_classifier.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.utils.prompt_message_util
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
@@ -169,6 +183,54 @@ ignore_imports =
core.workflow.nodes.llm.file_saver -> core.tools.signature
core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
core.workflow.nodes.tool.tool_node -> core.tools.errors
core.workflow.conversation_variable_updater -> core.variables
core.workflow.graph_engine.entities.commands -> core.variables.variables
core.workflow.nodes.agent.agent_node -> core.variables.segments
core.workflow.nodes.answer.answer_node -> core.variables
core.workflow.nodes.code.code_node -> core.variables.segments
core.workflow.nodes.code.code_node -> core.variables.types
core.workflow.nodes.code.entities -> core.variables.types
core.workflow.nodes.document_extractor.node -> core.variables
core.workflow.nodes.document_extractor.node -> core.variables.segments
core.workflow.nodes.http_request.executor -> core.variables.segments
core.workflow.nodes.http_request.node -> core.variables.segments
core.workflow.nodes.human_input.entities -> core.variables.consts
core.workflow.nodes.iteration.iteration_node -> core.variables
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
core.workflow.nodes.list_operator.node -> core.variables
core.workflow.nodes.list_operator.node -> core.variables.segments
core.workflow.nodes.llm.node -> core.variables
core.workflow.nodes.loop.loop_node -> core.variables
core.workflow.nodes.parameter_extractor.entities -> core.variables.types
core.workflow.nodes.parameter_extractor.exc -> core.variables.types
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
core.workflow.nodes.tool.tool_node -> core.variables.segments
core.workflow.nodes.tool.tool_node -> core.variables.variables
core.workflow.nodes.trigger_webhook.node -> core.variables.types
core.workflow.nodes.trigger_webhook.node -> core.variables.variables
core.workflow.nodes.variable_aggregator.entities -> core.variables.types
core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
core.workflow.nodes.variable_assigner.common.helpers -> core.variables
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
core.workflow.nodes.variable_assigner.v1.node -> core.variables
core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
core.workflow.nodes.variable_assigner.v2.node -> core.variables
core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
core.workflow.runtime.read_only_wrappers -> core.variables.segments
core.workflow.runtime.variable_pool -> core.variables
core.workflow.runtime.variable_pool -> core.variables.consts
core.workflow.runtime.variable_pool -> core.variables.segments
core.workflow.runtime.variable_pool -> core.variables.variables
core.workflow.utils.condition.processor -> core.variables
core.workflow.utils.condition.processor -> core.variables.segments
core.workflow.variable_loader -> core.variables
core.workflow.variable_loader -> core.variables.consts
core.workflow.workflow_type_encoder -> core.variables
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
@@ -185,17 +247,17 @@ ignore_imports =
core.workflow.workflow_entry -> models.enums
core.workflow.nodes.agent.agent_node -> services
core.workflow.nodes.tool.tool_node -> services
core.workflow.nodes.agent.agent_node -> core.model_runtime.token_buffer_memory
core.workflow.nodes.llm.llm_utils -> core.model_runtime.token_buffer_memory
core.workflow.nodes.llm.node -> core.model_runtime.token_buffer_memory
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.token_buffer_memory
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.token_buffer_memory
[importlinter:contract:model-runtime-no-internal-imports]
name = Model Runtime Internal Imports
type = forbidden
source_modules =
core.model_runtime.callbacks
core.model_runtime.entities
core.model_runtime.errors
core.model_runtime.model_providers
core.model_runtime.schema_validators
core.model_runtime.utils
core.model_runtime
forbidden_modules =
configs
controllers
@@ -225,7 +287,7 @@ forbidden_modules =
core.moderation
core.ops
core.plugin
core.model_runtime.prompt
core.prompt
core.provider_manager
core.rag
core.repositories
@@ -242,6 +304,13 @@ ignore_imports =
core.model_runtime.model_providers.model_provider_factory -> configs
core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
core.model_runtime.token_buffer_memory -> core.app.app_config.features.file_upload.manager
core.model_runtime.token_buffer_memory -> core.model_manager
core.model_runtime.token_buffer_memory -> core.prompt.utils.extract_thread_messages
core.model_runtime.token_buffer_memory -> core.workflow.file.file_manager
core.model_runtime.token_buffer_memory -> extensions.ext_database
core.model_runtime.token_buffer_memory -> models.model
core.model_runtime.token_buffer_memory -> models.workflow
[importlinter:contract:rsc]
name = RSC

File diff suppressed because one or more lines are too long

View File

@@ -15,11 +15,11 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.file import helpers as file_helpers
from core.workflow.variables.segment_group import SegmentGroup
from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type

View File

@@ -21,8 +21,8 @@ from controllers.console.app.workflow_draft_variable import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type

View File

@@ -36,9 +36,9 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
current_user, _ = current_account_with_tenant()
if current_user.status == AccountStatus.UNINITIALIZED:
@@ -214,9 +214,9 @@ def cloud_utm_record(view: Callable[P, R]):
return decorated
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
def setup_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
# check setup
if (
dify_config.EDITION == "SELF_HOSTED"

View File

@@ -17,7 +17,6 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
@@ -32,7 +31,8 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,

View File

@@ -17,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine

View File

@@ -21,7 +21,7 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.file import file_manager

View File

@@ -5,7 +5,7 @@ from core.app.app_config.entities import (
PromptTemplateEntity,
)
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.prompt.simple_prompt_transform import ModelMode
from models.model import AppMode

View File

@@ -32,8 +32,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.model_runtime.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import (

View File

@@ -25,6 +25,7 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -33,7 +34,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.variables.variables import Variable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client

View File

@@ -12,11 +12,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.moderation.base import ModerationError
from extensions.ext_database import db
from models.model import App, Conversation, Message

View File

@@ -22,7 +22,6 @@ from core.app.entities.queue_entities import (
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
@@ -33,14 +32,11 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.model_runtime.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.file.enums import FileTransferMethod, FileType
from extensions.ext_database import db

View File

@@ -11,9 +11,9 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file import File

View File

@@ -49,6 +49,7 @@ from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import (
@@ -61,7 +62,6 @@ from core.workflow.enums import (
from core.workflow.file import FILE_MODEL_IDENTITY, File
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db

View File

@@ -27,7 +27,7 @@ from core.app.entities.task_entities import (
CompletionAppStreamResponse,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic

View File

@@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
@@ -20,7 +21,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline

View File

@@ -1,12 +1,12 @@
import logging
from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.variables import VariableBase
logger = logging.getLogger(__name__)

View File

@@ -83,21 +83,14 @@ def fetch_model_config(
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.stop = tuple(stop)
return model_instance, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
@@ -105,6 +98,6 @@ def fetch_model_config(
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=completion_params,
parameters=node_data_model.completion_params,
stop=stop,
)
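
The rewritten block above pops the stop sequences out of the node's completion parameters before the model config entity is built; a minimal, self-contained sketch of that split (function and data here are illustrative, not from the repo):

```python
from typing import Any


def split_stop_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
    """Separate the 'stop' sequences from the remaining completion parameters."""
    params = dict(completion_params)  # copy so the caller's mapping is untouched
    stop = params.pop("stop", [])
    if not isinstance(stop, list):  # tolerate a malformed value
        stop = []
    return params, stop


params, stop = split_stop_params({"temperature": 0.7, "stop": ["\n\n"]})
assert params == {"temperature": 0.7}
assert stop == ["\n\n"]
```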

View File

@@ -52,10 +52,10 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from core.workflow.file import helpers as file_helpers
from core.workflow.file.enums import FileTransferMethod

View File

@@ -1,20 +1,14 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from typing import TYPE_CHECKING, Any, final
from typing_extensions import override
from configs import dify_config
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
@@ -29,11 +23,7 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import PromptMessageMemory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@@ -82,6 +72,7 @@ class DifyNodeFactory(NodeFactory):
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
self._code_providers: tuple[type[CodeNodeProvider], ...] = CodeNode.default_code_providers()
self._code_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
@@ -153,6 +144,7 @@ class DifyNodeFactory(NodeFactory):
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_providers=self._code_providers,
code_limits=self._code_limits,
)
@@ -179,8 +171,6 @@ class DifyNodeFactory(NodeFactory):
)
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return LLMNode(
id=node_id,
config=node_config,
@@ -188,8 +178,6 @@ class DifyNodeFactory(NodeFactory):
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.DATASOURCE:
@@ -220,7 +208,6 @@ class DifyNodeFactory(NodeFactory):
)
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
return QuestionClassifierNode(
id=node_id,
config=node_config,
@@ -228,11 +215,9 @@ class DifyNodeFactory(NodeFactory):
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
return ParameterExtractorNode(
id=node_id,
config=node_config,
@@ -240,7 +225,6 @@ class DifyNodeFactory(NodeFactory):
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
)
return node_class(
@@ -249,55 +233,3 @@ class DifyNodeFactory(NodeFactory):
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
node_data_model = ModelConfig.model_validate(node_data["model"])
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.stop = tuple(stop)
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance
def _build_memory_for_llm_node(
self,
*,
node_data: Mapping[str, Any],
model_instance: ModelInstance,
) -> PromptMessageMemory | None:
raw_memory_config = node_data.get("memory")
if raw_memory_config is None:
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
return llm_utils.fetch_memory(
variable_pool=self.graph_runtime_state.variable_pool,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,
)

View File

@@ -1,5 +1,6 @@
import logging
from collections.abc import Mapping
from enum import StrEnum
from threading import Lock
from typing import Any
@@ -13,7 +14,6 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
from core.helper.code_executor.template_transformer import TemplateTransformer
from core.helper.http_client_pooling import get_pooled_http_client
from core.workflow.nodes.code.entities import CodeLanguage
logger = logging.getLogger(__name__)
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
@@ -40,6 +40,12 @@ class CodeExecutionResponse(BaseModel):
data: Data
class CodeLanguage(StrEnum):
PYTHON3 = "python3"
JINJA2 = "jinja2"
JAVASCRIPT = "javascript"
def _build_code_executor_client() -> httpx.Client:
return httpx.Client(
verify=CODE_EXECUTION_SSL_VERIFY,

View File

@@ -5,7 +5,7 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any
from core.workflow.variables.utils import dumps_with_segments
from core.variables.utils import dumps_with_segments
class TemplateTransformer(ABC):

View File

@@ -27,10 +27,10 @@ from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from extensions.ext_storage import storage

View File

@@ -1,5 +1,5 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload
from configs import dify_config
@@ -38,9 +38,6 @@ class ModelInstance:
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
self.model_type_instance = self.provider_model_bundle.model_type_instance
self.load_balancing_manager = self._get_load_balancing_manager(
configuration=provider_model_bundle.configuration,

View File

@@ -83,21 +83,19 @@ def _merge_tool_call_delta(
tool_call.function.arguments += delta.function.arguments
def _build_llm_result_from_chunks(
def _build_llm_result_from_first_chunk(
model: str,
prompt_messages: Sequence[PromptMessage],
chunks: Iterator[LLMResultChunk],
) -> LLMResult:
"""
Build a single `LLMResult` by accumulating all returned chunks.
Build a single `LLMResult` from the first returned chunk.
Some models only support streaming output (e.g. Qwen3 open-source edition)
and the plugin side may still implement the response via a chunked stream,
so all chunks must be consumed and concatenated into a single ``LLMResult``.
This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
The ``usage`` is taken from the last chunk that carries it, which is the
typical convention for streaming responses (the final chunk contains the
aggregated token counts).
Note:
This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
"""
content = ""
content_list: list[PromptMessageContentUnionTypes] = []
@@ -106,27 +104,24 @@ def _build_llm_result_from_chunks(
tools_calls: list[AssistantPromptMessage.ToolCall] = []
try:
for chunk in chunks:
if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content)
first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
if chunk.delta.message.tool_calls:
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
if chunk.delta.usage:
usage = chunk.delta.usage
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception:
logger.exception("Error while consuming non-stream plugin chunk iterator.")
raise
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
finally:
# Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections).
close = getattr(chunks, "close", None)
if callable(close):
close()
try:
for _ in chunks:
pass
except Exception:
logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)
return LLMResult(
model=model,
@@ -179,7 +174,7 @@ def _normalize_non_stream_plugin_result(
) -> LLMResult:
if isinstance(result, LLMResult):
return result
return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result)
return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result)
def _increase_tool_call(
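
The renamed helper above builds the result from the first chunk only, but still exhausts the iterator in a `finally` block so the plugin runtime can release streaming resources such as HTTP connections; a generic, self-contained sketch of that pattern (names are illustrative):

```python
from collections.abc import Iterator


def first_then_drain(chunks: Iterator[str]) -> str | None:
    """Return the first chunk, then drain the rest so the producer can clean up."""
    try:
        return next(chunks, None)
    finally:
        for _ in chunks:  # releases any underlying streaming resources
            pass


def fake_stream() -> Iterator[str]:
    yield "first"
    yield "second"  # consumed by the drain, never returned


assert first_then_drain(fake_stream()) == "first"
```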

View File

@@ -14,7 +14,7 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.workflow.file import file_manager
from extensions.ext_database import db
from factories import file_factory

View File

@@ -14,9 +14,10 @@ class BaseTraceInstance(ABC):
Base trace instance for ops trace services
"""
@abstractmethod
def __init__(self, trace_config: BaseTracingConfig):
"""
Initializer for the trace instance.
Abstract initializer for the trace instance.
Distribute trace tasks by matching entities
"""
self.trace_config = trace_config

View File

@@ -41,8 +41,8 @@ logger = logging.getLogger(__name__)
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
def __getitem__(self, key: str) -> dict[str, Any]:
match key:
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
@@ -149,7 +149,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
}
case _:
raise KeyError(f"Unsupported tracing provider: {key}")
raise KeyError(f"Unsupported tracing provider: {provider}")
provider_config_map = OpsTraceProviderConfigMap()

View File

@@ -3,8 +3,6 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
@@ -14,13 +12,10 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.model_runtime.prompt.prompt_transform import PromptTransform
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import file_manager
from core.workflow.file.models import File
from core.workflow.runtime import VariablePool
@@ -49,8 +44,7 @@ class AdvancedPromptTransform(PromptTransform):
context: str | None,
memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
prompt_messages = []
@@ -65,7 +59,6 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
@@ -78,7 +71,6 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
@@ -93,8 +85,7 @@ class AdvancedPromptTransform(PromptTransform):
context: str | None,
memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
@@ -120,7 +111,6 @@ class AdvancedPromptTransform(PromptTransform):
parser=parser,
prompt_inputs=prompt_inputs,
model_config=model_config,
model_instance=model_instance,
)
if query:
@@ -156,8 +146,7 @@ class AdvancedPromptTransform(PromptTransform):
context: str | None,
memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
@@ -209,13 +198,8 @@ class AdvancedPromptTransform(PromptTransform):
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
if memory and memory_config:
prompt_messages = self._append_chat_histories(
memory,
memory_config,
prompt_messages,
model_config=model_config,
model_instance=model_instance,
)
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
if files and query is not None:
for file in files:
prompt_message_contents.append(
@@ -292,8 +276,7 @@ class AdvancedPromptTransform(PromptTransform):
role_prefix: MemoryConfig.RolePrefix,
parser: PromptTemplateParser,
prompt_inputs: Mapping[str, str],
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
model_config: ModelConfigWithCredentialsEntity,
) -> Mapping[str, str]:
prompt_inputs = dict(prompt_inputs)
if "#histories#" in parser.variable_keys:
@@ -303,11 +286,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
rest_tokens = self._calculate_rest_token(
[tmp_human_message],
model_config=model_config,
model_instance=model_instance,
)
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
histories = self._get_history_messages_from_memory(
memory=memory,

View File

@@ -3,14 +3,14 @@ from typing import cast
from core.app.entities.app_invoke_entities import (
ModelConfigWithCredentialsEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
PromptMessage,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.prompt_transform import PromptTransform
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.prompt_transform import PromptTransform
class AgentHistoryPromptTransform(PromptTransform):
@@ -41,7 +41,7 @@ class AgentHistoryPromptTransform(PromptTransform):
if not self.memory:
return prompt_messages
max_token_limit = self._calculate_rest_token(self.prompt_messages, model_config=self.model_config)
max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

View File

@@ -1,86 +1,48 @@
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:
def _resolve_model_runtime(
self,
*,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
) -> tuple[ModelInstance, AIModelEntity]:
if model_instance is None:
if model_config is None:
raise ValueError("Either model_config or model_instance must be provided.")
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_instance.credentials = model_config.credentials
model_instance.parameters = model_config.parameters
model_instance.stop = model_config.stop
model_schema = model_instance.model_type_instance.get_model_schema(
model=model_instance.model_name,
credentials=model_instance.credentials,
)
if model_schema is None:
if model_config is None:
raise ValueError("Model schema not found for the provided model instance.")
model_schema = model_config.model_schema
return model_instance, model_schema
def _append_chat_histories(
self,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
*,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
rest_tokens = self._calculate_rest_token(
prompt_messages,
model_config=model_config,
model_instance=model_instance,
)
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
prompt_messages.extend(histories)
return prompt_messages
def _calculate_rest_token(
self,
prompt_messages: list[PromptMessage],
*,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
model_instance, model_schema = self._resolve_model_runtime(
model_config=model_config,
model_instance=model_instance,
)
model_parameters = model_instance.parameters
rest_tokens = 2000
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_schema.parameter_rules:
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_parameters.get(parameter_rule.name)
or model_parameters.get(parameter_rule.use_template or "")
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
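
After this simplification, `_calculate_rest_token` reads everything from `model_config`: the budget left for chat history is the model's context size minus the configured `max_tokens` and the tokens already used by the prompt, falling back to 2000 when no context size is reported. A distilled sketch of that arithmetic (names are illustrative):

```python
def remaining_history_tokens(
    context_size: int | None,
    max_tokens: int,
    current_prompt_tokens: int,
    fallback: int = 2000,
) -> int:
    """Token budget left for chat history after reserving the completion budget."""
    if not context_size:  # model schema did not report CONTEXT_SIZE
        return fallback
    return context_size - max_tokens - current_prompt_tokens


assert remaining_history_tokens(8192, 1024, 700) == 6468
assert remaining_history_tokens(None, 1024, 700) == 2000
```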

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, cast
from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
@@ -15,9 +14,10 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.prompt.prompt_transform import PromptTransform
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import file_manager
from models.model import AppMode
@@ -252,7 +252,7 @@ class SimplePromptTransform(PromptTransform):
if memory:
tmp_human_message = UserPromptMessage(content=prompt)
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config=model_config)
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
histories = self._get_history_messages_from_memory(
memory=memory,
memory_config=MemoryConfig(

View File

@@ -1,6 +1,6 @@
from sqlalchemy import select
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message

View File

@@ -10,7 +10,7 @@ from core.model_runtime.entities import (
PromptMessageRole,
TextPromptMessageContent,
)
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.prompt.simple_prompt_transform import ModelMode
class PromptMessageUtil:

View File

@@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
vector=None, # ty: ignore [invalid-argument-type]
content=None, # ty: ignore [invalid-argument-type]
top_k=1,
filter=f"ref_doc_id='{id}'",
)
@@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
@@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
@@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
content=None, # ty: ignore [invalid-argument-type]
top_k=kwargs.get("top_k", 4),
filter=where_clause,
)
@@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
vector=None, # ty: ignore [invalid-argument-type]
content=query,
top_k=kwargs.get("top_k", 4),
filter=where_clause,

View File

@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)

View File

@@ -75,15 +75,15 @@ class BaseIndexProcessor(ABC):
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
) -> None:
):
raise NotImplementedError
@abstractmethod
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
raise NotImplementedError
@abstractmethod
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
raise NotImplementedError
@abstractmethod

View File

@@ -115,7 +115,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
) -> None:
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
@@ -130,7 +130,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
else:
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
@@ -196,7 +196,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
documents: list[Any] = []
all_multimodal_documents: list[Any] = []
if isinstance(chunks, list):
@@ -469,7 +469,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if not isinstance(result, LLMResult):
raise ValueError("Expected LLMResult when stream=False")
summary_content = result.message.get_text_content()
summary_content = getattr(result.message, "content", "")
usage = result.usage
# Deduct quota for summary generation (same as workflow nodes)

View File

@@ -126,7 +126,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
) -> None:
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
@@ -139,7 +139,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
@@ -272,7 +272,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_nodes.append(child_document)
return child_nodes
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
parent_childs = ParentChildStructureChunk.model_validate(chunks)
documents = []
for parent_child in parent_childs.parent_child_chunks:

View File

@@ -139,14 +139,14 @@ class QAIndexProcessor(BaseIndexProcessor):
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
) -> None:
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
@@ -206,7 +206,7 @@ class QAIndexProcessor(BaseIndexProcessor):
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
qa_chunks = QAStructureChunk.model_validate(chunks)
documents = []
for qa_chunk in qa_chunks.qa_chunks:

View File

@@ -23,18 +23,18 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.db.session_factory import session_factory
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService

View File

@@ -5,8 +5,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import llm_utils

View File

@@ -4,6 +4,8 @@ from uuid import uuid4
from pydantic import BaseModel, Discriminator, Field, Tag
from core.helper import encrypter
from .segments import (
ArrayAnySegment,
ArrayBooleanSegment,
@@ -25,14 +27,6 @@ from .segments import (
from .types import SegmentType
def _obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:
return "*" * 20
return token[:6] + "*" * 12 + token[-2:]
class VariableBase(Segment):
"""
A variable is a segment that has a name.
@@ -92,7 +86,7 @@ class SecretVariable(StringVariable):
@property
def log(self) -> str:
return _obfuscated_token(self.value)
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, VariableBase):
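
The one-off _obfuscated_token helper moves into the shared core.helper.encrypter module so every caller masks secrets identically. A sketch of the behavior, assuming the shared helper keeps the semantics of the removed function above:

def obfuscated_token(token: str) -> str:
    # Empty stays empty; short tokens are fully masked so their
    # length is not leaked; long tokens keep a recognizable
    # prefix and suffix for debugging.
    if not token:
        return token
    if len(token) <= 8:
        return "*" * 20
    return token[:6] + "*" * 12 + token[-2:]

assert obfuscated_token("sk-abcdefghijklmnop") == "sk-abc************op"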

View File

@@ -1,7 +1,7 @@
import abc
from typing import Protocol
from core.workflow.variables import VariableBase
from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):

View File

@@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.workflow.variables.variables import Variable
from core.variables.variables import Variable
class CommandType(StrEnum):

View File

@@ -9,6 +9,7 @@ from __future__ import annotations
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
@@ -76,10 +77,13 @@ class GraphEngine:
config: GraphEngineConfig = _DEFAULT_CONFIG,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
self._config = config
@@ -159,6 +163,7 @@ class GraphEngine:
layers=self._layers,
execution_context=execution_context,
config=self._config,
stop_event=self._stop_event,
)
# === Orchestration ===
@@ -189,6 +194,7 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
@@ -308,6 +314,7 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
@@ -341,6 +348,7 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
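
The thread-control change above is the heart of this file: the engine now owns a single threading.Event and hands the same instance to the runtime state, the dispatcher, and the worker pool, so one set() cancels every subsystem cooperatively. A reduced sketch of the wiring (class names simplified from the ones above):

import threading

class Subsystem:
    def __init__(self, stop_event: threading.Event) -> None:
        self._stop_event = stop_event  # shared, never created locally

    def should_stop(self) -> bool:
        return self._stop_event.is_set()

class Engine:
    def __init__(self) -> None:
        self._stop_event = threading.Event()
        # Every subsystem observes the same event object.
        self._dispatcher = Subsystem(self._stop_event)
        self._workers = [Subsystem(self._stop_event) for _ in range(3)]

    def start(self) -> None:
        self._stop_event.clear()  # allow re-running after a stop

    def stop(self) -> None:
        self._stop_event.set()  # one call, everything winds down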

View File

@@ -44,6 +44,7 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@@ -61,7 +62,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._stop_event = stop_event
self._start_time: float | None = None
def start(self) -> None:
@@ -69,14 +70,12 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=2.0)

View File

@@ -42,6 +42,7 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
execution_context: IExecutionContext | None = None,
) -> None:
@@ -62,13 +63,16 @@ class Worker(threading.Thread):
self._graph = graph
self._worker_id = worker_id
self._execution_context = execution_context
self._stop_event = threading.Event()
self._stop_event = stop_event
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
@property
def is_idle(self) -> bool:
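
With the shared event in place, a worker's run loop polls it between queue reads instead of consulting a private flag, which is why stop() above degrades to a documented no-op. A sketch of such a loop, assuming a short poll timeout so the stop flag is re-checked regularly:

import queue
import threading

def worker_loop(
    tasks: queue.Queue[str],
    stop_event: threading.Event,
    poll_timeout: float = 0.1,
) -> None:
    while not stop_event.is_set():
        try:
            task = tasks.get(timeout=poll_timeout)
        except queue.Empty:
            continue  # nothing queued; loop back and re-check the flag
        print(f"processing {task}")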

View File

@@ -37,6 +37,7 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
config: GraphEngineConfig,
execution_context: IExecutionContext | None = None,
) -> None:
@@ -63,6 +64,7 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@@ -133,6 +135,7 @@ class WorkerPool:
layers=self._layers,
worker_id=worker_id,
execution_context=self._execution_context,
stop_event=self._stop_event,
)
worker.start()

View File

@@ -11,10 +11,10 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
@@ -25,6 +25,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@@ -43,7 +44,6 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy

View File

@@ -3,7 +3,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData

View File

@@ -1,13 +1,13 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.variables import ArrayFileSegment, FileSegment, Segment
class AnswerNode(Node[AnswerNodeData]):

View File

@@ -302,6 +302,10 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def _should_stop(self) -> bool:
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@@ -370,6 +374,21 @@ class Node(Generic[NodeDataT]):
yield event
else:
yield event
if self._should_stop():
error_message = "Execution cancelled"
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
),
error=error_message,
)
return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(
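
The same cancellation signal reaches individual nodes: after draining its event stream, a node consults _should_stop() and, if the engine has signalled a stop, yields a terminal failure event instead of a success. A reduced sketch of that control flow, with event objects simplified to strings:

from collections.abc import Callable, Generator

def run(events: list[str], should_stop: Callable[[], bool]) -> Generator[str, None, None]:
    for event in events:
        yield event  # stream ordinary events first
    if should_stop():
        # Cancellation surfaces as an explicit terminal failure, so the
        # node is recorded as FAILED rather than left in limbo.
        yield "NodeRunFailedEvent: Execution cancelled"
        return
    yield "NodeRunSucceededEvent"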

View File

@@ -1,15 +1,17 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.variables.segments import ArrayFileSegment
from core.workflow.variables.types import SegmentType
from .exc import (
CodeNodeError,
@@ -34,44 +36,12 @@ class WorkflowCodeExecutor(Protocol):
def is_execution_error(self, error: Exception) -> bool: ...
def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]:
return {
"type": "code",
"config": {
"variables": [
{"variable": "arg1", "value_selector": []},
{"variable": "arg2", "value_selector": []},
],
"code_language": language,
"code": code,
"outputs": {"result": {"type": "string", "children": None}},
},
}
_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = {
CodeLanguage.PYTHON3: dedent(
"""
def main(arg1: str, arg2: str):
return {
"result": arg1 + arg2,
}
"""
),
CodeLanguage.JAVASCRIPT: dedent(
"""
function main({arg1, arg2}) {
return {
result: arg1 + arg2
}
}
"""
),
}
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
Python3CodeProvider,
JavascriptCodeProvider,
)
_limits: CodeNodeLimits
def __init__(
@@ -82,6 +52,7 @@ class CodeNode(Node[CodeNodeData]):
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: WorkflowCodeExecutor,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits,
) -> None:
super().__init__(
@@ -91,6 +62,9 @@ class CodeNode(Node[CodeNodeData]):
graph_runtime_state=graph_runtime_state,
)
self._code_executor: WorkflowCodeExecutor = code_executor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
)
self._limits = code_limits
@classmethod
@@ -104,10 +78,15 @@ class CodeNode(Node[CodeNodeData]):
if filters:
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language)
if default_code is None:
raise CodeNodeError(f"Unsupported code language: {code_language}")
return _build_default_config(language=code_language, code=default_code)
code_provider: type[CodeNodeProvider] = next(
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
)
return code_provider.get_default_config()
@classmethod
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
return cls._DEFAULT_CODE_PROVIDERS
@classmethod
def version(cls) -> str:
@@ -129,6 +108,7 @@ class CodeNode(Node[CodeNodeData]):
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
_ = self._select_code_provider(code_language)
result = self._code_executor.execute(
language=code_language,
code=code,
@@ -150,6 +130,12 @@ class CodeNode(Node[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
for provider in self._code_providers:
if provider.is_accept_language(code_language):
return provider
raise CodeNodeError(f"Unsupported code language: {code_language}")
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
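
The static default-config table gives way to provider classes queried in declaration order: the first provider that accepts the language wins, and an unknown language raises instead of silently falling through. A sketch of the lookup under that shape (provider classes reduced to the one method that matters here):

class PythonProvider:
    @staticmethod
    def is_accept_language(lang: str) -> bool:
        return lang == "python3"

class JavascriptProvider:
    @staticmethod
    def is_accept_language(lang: str) -> bool:
        return lang == "javascript"

_PROVIDERS = (PythonProvider, JavascriptProvider)

def select_provider(lang: str) -> type:
    for provider in _PROVIDERS:
        if provider.is_accept_language(lang):
            return provider
    raise ValueError(f"Unsupported code language: {lang}")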

View File

@@ -1,18 +1,11 @@
from enum import StrEnum
from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.variables.types import SegmentType
class CodeLanguage(StrEnum):
PYTHON3 = "python3"
JINJA2 = "jinja2"
JAVASCRIPT = "javascript"
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
[

View File

@@ -21,12 +21,12 @@ from docx.table import Table
from docx.text.paragraph import Paragraph
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.file import File, FileTransferMethod, file_manager
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables import ArrayFileSegment
from core.workflow.variables.segments import ArrayStringSegment, FileSegment
from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

View File

@@ -10,9 +10,11 @@ from urllib.parse import urlencode, urlparse
import httpx
from json_repair import repair_json
from core.helper.ssrf_proxy import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.file.enums import FileTransferMethod
from core.workflow.file.file_manager import file_manager as default_file_manager
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayFileSegment, FileSegment
from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
@@ -79,8 +81,8 @@ class Executor:
http_request_config: HttpRequestNodeConfig,
max_retries: int | None = None,
ssl_verify: bool | None = None,
http_client: HttpClientProtocol,
file_manager: FileManagerProtocol,
http_client: HttpClientProtocol | None = None,
file_manager: FileManagerProtocol | None = None,
):
self._http_request_config = http_request_config
# If authorization API key is present, convert the API key using the variable pool
@@ -114,8 +116,8 @@ class Executor:
self.max_retries = (
max_retries if max_retries is not None else self._http_request_config.ssrf_default_max_retries
)
self._http_client = http_client
self._file_manager = file_manager
self._http_client = http_client or ssrf_proxy
self._file_manager = file_manager or default_file_manager
# init template
self.variable_pool = variable_pool
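
Executor's collaborators flip from required keyword arguments back to optional ones with module-level fallbacks (ssrf_proxy for HTTP, the shared file_manager for downloads): production call sites stay terse while tests can still inject fakes. The pattern in isolation, with illustrative stand-in names:

class _DefaultClient:
    def get(self, url: str) -> str:
        return f"GET {url}"

_default_client = _DefaultClient()

class Executor:
    def __init__(self, *, http_client: _DefaultClient | None = None) -> None:
        # Fall back to the shared module-level default; tests pass a fake.
        self._http_client = http_client or _default_client

Executor()                               # production: shared default
Executor(http_client=_DefaultClient())   # tests: injected stand-in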

View File

@@ -3,15 +3,18 @@ import mimetypes
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.file import File, FileTransferMethod
from core.workflow.file.file_manager import file_manager as default_file_manager
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol
from core.workflow.variables.segments import ArrayFileSegment
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory
from .config import build_http_request_config, resolve_http_request_config
@@ -42,9 +45,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_runtime_state: "GraphRuntimeState",
*,
http_request_config: HttpRequestNodeConfig,
http_client: HttpClientProtocol,
tool_file_manager_factory: Callable[[], ToolFileManagerProtocol],
file_manager: FileManagerProtocol,
http_client: HttpClientProtocol | None = None,
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
file_manager: FileManagerProtocol | None = None,
) -> None:
super().__init__(
id=id,
@@ -52,11 +55,10 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._http_request_config = http_request_config
self._http_client = http_client
self._http_client = http_client or ssrf_proxy
self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager
self._file_manager = file_manager or default_file_manager
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:

View File

@@ -10,10 +10,10 @@ from typing import Annotated, Any, ClassVar, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables.consts import SELECTORS_LENGTH
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit

View File

@@ -7,6 +7,9 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@@ -33,9 +36,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from core.workflow.variables import IntegerVariable, NoneSegment
from core.workflow.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.variables.variables import Variable
from libs.datetime_utils import naive_utc_now
from .exc import (

View File

@@ -5,6 +5,12 @@ from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import (
ArrayFileSegment,
FileSegment,
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
NodeType,
@@ -16,12 +22,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from core.workflow.variables import (
ArrayFileSegment,
FileSegment,
StringSegment,
)
from core.workflow.variables.segments import ArrayObjectSegment
from .entities import KnowledgeRetrievalNodeData
from .exc import (

View File

@@ -1,12 +1,12 @@
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, TypeVar
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.workflow.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError

View File

@@ -4,11 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector

View File

@@ -5,17 +5,21 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation
@@ -25,14 +29,46 @@ from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
model_instance.model_name,
model_instance.credentials,
def fetch_model_config(
*,
node_data_model: ModelConfig,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ValueError(f"Model schema not found for {model_instance.model_name}")
return model_schema
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=node_data_model.completion_params,
stop=stop,
)
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
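
fetch_model_schema is folded into the broader fetch_model_config, which assembles a ready-to-invoke (instance, config) pair from two narrow ports: a CredentialsProvider and a ModelFactory. A compressed sketch of that flow, with the entity types stubbed out rather than imported:

from typing import Any, Protocol

class CredentialsProvider(Protocol):
    def fetch(self, provider: str, model: str) -> dict[str, Any]: ...

class ModelFactory(Protocol):
    def init_model_instance(self, provider: str, model: str) -> Any: ...

def fetch_model_config(
    provider: str,
    model: str,
    credentials_provider: CredentialsProvider,
    model_factory: ModelFactory,
) -> tuple[Any, dict[str, Any]]:
    credentials = credentials_provider.fetch(provider, model)
    instance = model_factory.init_model_instance(provider, model)
    # The real function also validates provider-model status and the
    # model schema before returning; errors raise early here too.
    return instance, {"model": model, "credentials": credentials}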

View File

@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@@ -36,12 +37,21 @@ from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
@@ -62,16 +72,8 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from core.workflow.variables import (
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
@@ -81,6 +83,7 @@ from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
)
from .exc import (
InvalidContextStructureError,
@@ -113,8 +116,6 @@ class LLMNode(Node[LLMNodeData]):
_llm_file_saver: LLMFileSaver
_credentials_provider: CredentialsProvider
_model_factory: ModelFactory
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
def __init__(
self,
@@ -125,8 +126,6 @@ class LLMNode(Node[LLMNodeData]):
*,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@@ -140,8 +139,6 @@ class LLMNode(Node[LLMNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@@ -205,12 +202,29 @@ class LLMNode(Node[LLMNodeData]):
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
# fetch model config
model_instance = self._model_instance
model_name = model_instance.model_name
model_provider = model_instance.provider
model_stop = model_instance.stop
model_instance, model_config = self._fetch_model_config(
node_data_model=self.node_data.model,
)
model_name = getattr(model_instance, "model_name", None)
if not isinstance(model_name, str):
model_name = model_config.model
model_provider = getattr(model_instance, "provider", None)
if not isinstance(model_provider, str):
model_provider = model_config.provider
model_schema = model_instance.model_type_instance.get_model_schema(
model_name,
model_instance.credentials,
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_name}")
memory = self._memory
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
)
query: str | None = None
if self.node_data.memory:
@@ -226,7 +240,9 @@ class LLMNode(Node[LLMNodeData]):
context=context,
memory=memory,
model_instance=model_instance,
stop=model_stop,
model_schema=model_schema,
model_parameters=self.node_data.model.completion_params,
stop=model_config.stop,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
@@ -238,6 +254,7 @@ class LLMNode(Node[LLMNodeData]):
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=self.node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@@ -354,6 +371,7 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod
def invoke_llm(
*,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None = None,
@@ -366,10 +384,11 @@ class LLMNode(Node[LLMNodeData]):
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_parameters = model_instance.parameters
invoke_model_parameters = dict(model_parameters)
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
)
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
@@ -383,7 +402,7 @@ class LLMNode(Node[LLMNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=output_schema,
model_parameters=invoke_model_parameters,
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
user=user_id,
@@ -393,7 +412,7 @@ class LLMNode(Node[LLMNodeData]):
invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=invoke_model_parameters,
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
user=user_id,
@@ -752,14 +771,33 @@ class LLMNode(Node[LLMNodeData]):
return None
def _fetch_model_config(
self,
*,
node_data_model: ModelConfig,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config(
node_data_model=node_data_model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
completion_params = model_config_with_cred.parameters
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
node_data_model.completion_params = completion_params
return model, model_config_with_cred
@staticmethod
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence[File],
context: str | None = None,
memory: PromptMessageMemory | None = None,
memory: TokenBufferMemory | None = None,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
memory_config: MemoryConfig | None = None,
@@ -770,7 +808,6 @@ class LLMNode(Node[LLMNodeData]):
context_files: list[File] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
# For chat model
@@ -789,6 +826,8 @@ class LLMNode(Node[LLMNodeData]):
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
@@ -826,6 +865,8 @@ class LLMNode(Node[LLMNodeData]):
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
@@ -1275,23 +1316,23 @@ def _calculate_rest_token(
*,
prompt_messages: list[PromptMessage],
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> int:
rest_tokens = 2000
runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in runtime_model_schema.parameter_rules:
for parameter_rule in model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
runtime_model_parameters.get(parameter_rule.name)
or runtime_model_parameters.get(str(parameter_rule.use_template))
model_parameters.get(parameter_rule.name)
or model_parameters.get(str(parameter_rule.use_template))
or 0
)
@@ -1303,9 +1344,11 @@ def _calculate_rest_token(
def _handle_memory_chat_mode(
*,
memory: PromptMessageMemory | None,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
@@ -1313,6 +1356,8 @@ def _handle_memory_chat_mode(
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
@@ -1323,9 +1368,11 @@ def _handle_memory_chat_mode(
def _handle_memory_completion_mode(
*,
memory: PromptMessageMemory | None,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> str:
memory_text = ""
# Get history text from memory for completion model
@@ -1333,51 +1380,20 @@ def _handle_memory_completion_mode(
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_messages = memory.get_history_prompt_messages(
memory_text = memory.get_history_prompt_text(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
memory_text = _convert_history_messages_to_text(
history_messages=memory_messages,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
return "\n".join(string_messages)
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
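
_calculate_rest_token now receives the model_schema and model_parameters explicitly instead of re-deriving them from the instance, but the budget itself is unchanged: context window minus the configured max_tokens minus the tokens already spent on the prompt. With an 8,192-token window, max_tokens=1024, and a 700-token prompt, history may consume at most 8192 - 1024 - 700 = 6468 tokens. The arithmetic as a sketch:

def calculate_rest_tokens(
    context_size: int | None,
    max_tokens: int,
    current_prompt_tokens: int,
    default: int = 2000,
) -> int:
    if context_size is None:
        return default  # schema exposes no window; keep the fallback
    return max(context_size - max_tokens - current_prompt_tokens, 0)

assert calculate_rest_tokens(8192, 1024, 700) == 6468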

View File

@@ -1,10 +1,8 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Protocol
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessage
class CredentialsProvider(Protocol):
@@ -21,13 +19,3 @@ class ModelFactory(Protocol):
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages for LLM nodes."""
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...

View File

@@ -3,9 +3,9 @@ from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.variables.types import SegmentType
_VALID_VAR_TYPE = frozenset(
[

View File

@@ -6,6 +6,7 @@ from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
NodeExecutionType,
NodeType,
@@ -30,7 +31,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from core.workflow.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now

View File

@@ -7,10 +7,10 @@ from pydantic import (
field_validator,
)
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
from core.workflow.variables.types import SegmentType
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"

View File

@@ -1,6 +1,6 @@
from typing import Any
from core.workflow.variables.types import SegmentType
from core.variables.types import SegmentType
class ParameterExtractorNodeError(ValueError):

View File

@@ -5,7 +5,7 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.memory.token_buffer_memory import TokenBufferMemory
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -19,19 +19,20 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from core.workflow.variables.types import ArrayValidation, SegmentType
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
@@ -94,7 +95,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_type = NodeType.PARAMETER_EXTRACTOR
_model_instance: ModelInstance
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
@@ -107,7 +109,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
) -> None:
super().__init__(
id=id,
@@ -117,7 +118,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -155,14 +155,18 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
else []
)
model_instance = self._model_instance
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(
model=model_config.model,
credentials=model_config.credentials,
)
if not model_schema:
raise ModelSchemaNotFoundError("Model schema not found")
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@@ -180,7 +184,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_instance=model_instance,
model_config=model_config,
memory=memory,
files=files,
vision_detail=node_data.vision.configs.detail,
@@ -191,7 +195,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_instance=model_instance,
model_config=model_config,
memory=memory,
files=files,
vision_detail=node_data.vision.configs.detail,
@@ -207,23 +211,24 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
}
process_data = {
"model_mode": node_data.model.mode,
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=node_data.model.mode, prompt_messages=prompt_messages
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"usage": None,
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
"tool_call": None,
"model_provider": model_instance.provider,
"model_name": model_instance.model_name,
"model_provider": model_config.provider,
"model_name": model_config.model,
}
try:
text, usage, tool_call = self._invoke(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
tools=prompt_message_tools,
stop=model_instance.stop,
stop=model_config.stop,
)
process_data["usage"] = jsonable_encoder(usage)
process_data["tool_call"] = jsonable_encoder(tool_call)
@@ -285,16 +290,17 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
def _invoke(
self,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: Sequence[str],
stop: list[str],
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=dict(model_instance.parameters),
model_parameters=node_data_model.completion_params,
tools=tools,
stop=list(stop),
stop=stop,
stream=False,
user=self.user_id,
)
@@ -318,7 +324,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -331,13 +337,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
context="",
)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
prompt_template = self._get_function_calling_prompt_template(
node_data, query, variable_pool, memory, rest_token
)
@@ -349,7 +349,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
context="",
memory_config=node_data.memory,
memory=None,
model_instance=model_instance,
model_config=model_config,
image_detail_config=vision_detail,
)
@@ -406,7 +406,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -421,7 +421,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data=data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
model_config=model_config,
memory=memory,
files=files,
vision_detail=vision_detail,
@@ -431,7 +431,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data=data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
model_config=model_config,
memory=memory,
files=files,
vision_detail=vision_detail,
@@ -444,7 +444,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -454,11 +454,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
context="",
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
@@ -471,7 +467,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
context="",
memory_config=node_data.memory,
memory=memory,
model_instance=model_instance,
model_config=model_config,
image_detail_config=vision_detail,
)
@@ -482,7 +478,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -492,11 +488,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
context="",
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data,
@@ -516,7 +508,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
context="",
memory_config=node_data.memory,
memory=None,
model_instance=model_instance,
model_config=model_config,
image_detail_config=vision_detail,
)
@@ -777,16 +769,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
context: str | None,
) -> int:
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
if not model_schema:
raise ModelSchemaNotFoundError("Model schema not found")
        if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
else:
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
@@ -799,28 +796,27 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
context=context,
memory_config=node_data.memory,
memory=None,
model_instance=model_instance,
model_config=model_config,
)
rest_tokens = 2000
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
curr_message_tokens = (
model_type_instance.get_num_tokens(
model_instance.model_name, model_instance.credentials, prompt_messages
)
+ 1000
model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000
) # add 1000 to ensure tool call messages
max_tokens = 0
for parameter_rule in model_schema.parameter_rules:
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_instance.parameters.get(parameter_rule.name)
or model_instance.parameters.get(parameter_rule.use_template or "")
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@@ -828,6 +824,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return rest_tokens
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config.
"""
if not self._model_instance or not self._model_config:
self._model_instance, self._model_config = llm_utils.fetch_model_config(
node_data_model=node_data_model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
return self._model_instance, self._model_config
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

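Taken together, the hunks above change the parameter extractor to resolve its model lazily and to branch on the schema's tool-calling features. Below is a minimal, self-contained sketch of that flow; the types and the fetch call are simplified stand-ins for the real Dify entities, not the actual API.

from dataclasses import dataclass, field
from enum import Enum


class ModelFeature(Enum):
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"


@dataclass
class ModelSchema:
    features: list[ModelFeature] = field(default_factory=list)


@dataclass
class ModelConfig:
    model: str
    model_schema: ModelSchema


class ParameterExtractorSketch:
    def __init__(self) -> None:
        self._model_config: ModelConfig | None = None

    def _fetch_model_config(self) -> ModelConfig:
        # Memoized: the expensive provider lookup runs at most once per node,
        # mirroring the new _fetch_model_config helper above.
        if self._model_config is None:
            self._model_config = ModelConfig(
                model="gpt-4o",  # hypothetical model name
                model_schema=ModelSchema(features=[ModelFeature.TOOL_CALL]),
            )
        return self._model_config

    def choose_strategy(self) -> str:
        config = self._fetch_model_config()
        # Same set-intersection test as the hunk above: prefer native
        # function calling when the model supports any tool-call feature.
        if set(config.model_schema.features) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
            return "function_calling"
        return "prompt_engineering"


assert ParameterExtractorSketch().choose_strategy() == "function_calling"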
View File

@@ -27,16 +27,3 @@ class HttpClientProtocol(Protocol):
class FileManagerProtocol(Protocol):
def download(self, f: File, /) -> bytes: ...
class ToolFileManagerProtocol(Protocol):
def create_file_by_raw(
self,
*,
user_id: str,
tenant_id: str,
conversation_id: str | None,
file_binary: bytes,
mimetype: str,
filename: str | None = None,
) -> Any: ...

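For context on the hunk above: ToolFileManagerProtocol was a typing.Protocol, so any object exposing a create_file_by_raw method with a matching shape satisfied it structurally, with no inheritance required. A minimal sketch of that pattern, using a hypothetical in-memory implementation:

from typing import Any, Protocol


class ToolFileManagerProtocol(Protocol):
    def create_file_by_raw(
        self,
        *,
        user_id: str,
        tenant_id: str,
        conversation_id: str | None,
        file_binary: bytes,
        mimetype: str,
        filename: str | None = None,
    ) -> Any: ...


class InMemoryToolFileManager:
    # No inheritance from the Protocol: matching the method shape is enough.
    def create_file_by_raw(
        self,
        *,
        user_id: str,
        tenant_id: str,
        conversation_id: str | None,
        file_binary: bytes,
        mimetype: str,
        filename: str | None = None,
    ) -> Any:
        return {"user_id": user_id, "size": len(file_binary), "mimetype": mimetype}


manager: ToolFileManagerProtocol = InMemoryToolFileManager()  # type-checks structurally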
View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig

View File

@@ -3,12 +3,14 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.memory.token_buffer_memory import TokenBufferMemory
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
NodeExecutionType,
@@ -20,12 +22,7 @@ from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import (
LLMNode,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -55,7 +52,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_llm_file_saver: LLMFileSaver
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
def __init__(
self,
@@ -66,7 +62,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@@ -80,7 +75,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@@ -101,8 +95,18 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
query = variable.value if variable else None
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
# fetch model config
model_instance, model_config = llm_utils.fetch_model_config(
node_data_model=node_data.model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
model_schema = model_instance.model_type_instance.get_model_schema(
model_instance.model_name,
model_instance.credentials,
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_instance.model_name}")
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@@ -127,7 +131,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query or "",
model_instance=model_instance,
model_config=model_config,
context="",
)
prompt_template = self._get_prompt_template(
@@ -145,7 +149,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
sys_query="",
memory=memory,
model_instance=model_instance,
stop=model_instance.stop,
model_schema=model_schema,
model_parameters=node_data.model.completion_params,
stop=model_config.stop,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
@@ -160,6 +166,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
try:
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@@ -198,14 +205,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
category_name = classes_map[category_id_result]
category_id = category_id_result
process_data = {
"model_mode": node_data.model.mode,
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=node_data.model.mode, prompt_messages=prompt_messages
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_instance.provider,
"model_name": model_instance.model_name,
"model_provider": model_config.provider,
"model_name": model_config.model,
}
outputs = {
"class_name": category_name,
@@ -278,40 +285,39 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
context: str | None,
) -> int:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_messages, _ = LLMNode.fetch_prompt_messages(
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
sys_query="",
sys_files=[],
inputs={},
query="",
files=[],
context=context,
memory=None,
model_instance=model_instance,
stop=model_instance.stop,
memory_config=node_data.memory,
vision_enabled=False,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
memory=None,
model_config=model_config,
)
rest_tokens = 2000
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_schema.parameter_rules:
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_instance.parameters.get(parameter_rule.name)
or model_instance.parameters.get(parameter_rule.use_template or "")
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens

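Both _calculate_rest_token variants above share the same budget arithmetic: context window minus the completion budget minus the rendered prompt, with the real nodes falling back to a 2000-token default when the schema declares no context size. A worked sketch with illustrative numbers (all values hypothetical):

# Token budget arithmetic behind _calculate_rest_token (illustrative values).
rest_tokens = 2_000      # default used when the schema has no CONTEXT_SIZE
context_size = 128_000   # CONTEXT_SIZE model property
max_tokens = 4_096       # resolved from the "max_tokens" parameter rule
prompt_tokens = 1_850    # tokens in the rendered prompt messages
prompt_tokens += 1_000   # parameter extractor only: headroom for tool-call messages

if context_size:
    rest_tokens = context_size - max_tokens - prompt_tokens
print(rest_tokens)  # 121054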
View File

@@ -11,6 +11,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@@ -21,8 +23,6 @@ from core.workflow.file import File, FileTransferMethod
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.workflow.variables.variables import ArrayAnyVariable
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile

View File

@@ -2,14 +2,14 @@ import logging
from collections.abc import Mapping
from typing import Any
from core.variables.types import SegmentType
from core.variables.variables import FileVariable
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.file import FileTransferMethod
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables.types import SegmentType
from core.workflow.variables.variables import FileVariable
from factories import file_factory
from factories.variable_factory import build_segment_with_type

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.types import SegmentType
class AdvancedSettings(BaseModel):

View File

@@ -1,10 +1,10 @@
from collections.abc import Mapping
from core.variables.segments import Segment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorNodeData
from core.workflow.variables.segments import Segment
class VariableAggregatorNode(Node[VariableAggregatorNodeData]):

View File

@@ -3,9 +3,9 @@ from typing import Any, TypeVar
from pydantic import BaseModel
from core.workflow.variables import Segment
from core.workflow.variables.consts import SELECTORS_LENGTH
from core.workflow.variables.types import SegmentType
from core.variables import Segment
from core.variables.consts import SELECTORS_LENGTH
from core.variables.types import SegmentType
# Use double underscore (`__`) prefix for internal variables
# to minimize risk of collision with user-defined variable names.
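A tiny illustration of that naming convention; the constant names below are hypothetical, not taken from the codebase:

# Internal selector names carry a "__" prefix per the comment above.
CONVERSATION_ID_SELECTOR = ["sys", "__conversation_id"]
LOOP_INDEX_SELECTOR = ["loop_node", "__index"]


def is_internal_variable(name: str) -> bool:
    # The double-underscore prefix keeps internal names out of the
    # namespace users are likely to pick for their own variables.
    return name.startswith("__")


assert is_internal_variable("__conversation_id")
assert not is_internal_variable("query")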

Some files were not shown because too many files have changed in this diff Show More