Mirror of https://github.com/langgenius/dify.git

Comparing `move-core-...move-token` (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 0d3aab5901 |  |

Makefile (5 changed lines)
@@ -68,9 +68,10 @@ lint:
 	@echo "✅ Linting complete"
 
 type-check:
-	@echo "📝 Running type checks (basedpyright + mypy)..."
+	@echo "📝 Running type checks (basedpyright + mypy + ty)..."
 	@./dev/basedpyright-check $(PATH_TO_CHECK)
 	@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
+	@cd api && uv run ty check
 	@echo "✅ Type checks complete"
 
 test:

@@ -131,7 +132,7 @@ help:
 	@echo " make format - Format code with ruff"
 	@echo " make check - Check code with ruff"
 	@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
-	@echo " make type-check - Run type checks (basedpyright, mypy)"
+	@echo " make type-check - Run type checks (basedpyright, mypy, ty)"
 	@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
 	@echo ""
 	@echo "Docker Build Targets:"
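The `type-check` target now chains a third checker, `ty`, after basedpyright and mypy, invoked through the `uv`-managed `api` workspace. If one wanted to run only the new checker in isolation, a minimal standalone target (a sketch, not part of this commit) could look like:

```make
# Sketch: run only the newly added ty type checker against the api workspace.
ty-check:
	@cd api && uv run ty check
```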
@@ -91,7 +91,7 @@ forbidden_modules =
     core.moderation
     core.ops
     core.plugin
-    core.model_runtime.prompt
+    core.prompt
     core.provider_manager
     core.rag
     core.repositories
@@ -106,9 +106,11 @@ ignore_imports =
    core.workflow.nodes.agent.agent_node -> core.provider_manager
    core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
    core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
    core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
    core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
    core.workflow.nodes.llm.llm_utils -> configs
+   core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
    core.workflow.nodes.llm.llm_utils -> core.model_manager
+   core.workflow.nodes.llm.protocols -> core.model_manager
    core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
@@ -127,10 +129,14 @@ ignore_imports =
    core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
-   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.advanced_prompt_transform
-   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.simple_prompt_transform
+   core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
+   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
+   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
+   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
    core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
-   core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.simple_prompt_transform
+   core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
+   core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
+   core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
    core.workflow.nodes.start.entities -> core.app.app_config.entities
    core.workflow.nodes.start.start_node -> core.app.app_config.entities
    core.workflow.workflow_entry -> core.app.apps.exc
@@ -139,25 +145,33 @@ ignore_imports =
    core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
    core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
    core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
+   core.workflow.nodes.llm.llm_utils -> core.variables.segments
+   core.workflow.nodes.loop.entities -> core.variables.types
    core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
    core.workflow.nodes.tool.tool_node -> models
    core.workflow.nodes.agent.agent_node -> models.model
+   core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
+   core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
+   core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
+   core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
    core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
    core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
    core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
+   core.workflow.nodes.llm.node -> core.helper.code_executor
+   core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
    core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
    core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
    core.workflow.nodes.llm.node -> core.model_manager
-   core.workflow.nodes.agent.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
-   core.workflow.nodes.llm.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
-   core.workflow.nodes.llm.llm_utils -> core.model_runtime.prompt.entities.advanced_prompt_entities
-   core.workflow.nodes.llm.node -> core.model_runtime.prompt.entities.advanced_prompt_entities
-   core.workflow.nodes.llm.node -> core.model_runtime.prompt.utils.prompt_message_util
-   core.workflow.nodes.parameter_extractor.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
-   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.entities.advanced_prompt_entities
-   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.utils.prompt_message_util
-   core.workflow.nodes.question_classifier.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
-   core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.utils.prompt_message_util
+   core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
+   core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
+   core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
+   core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
+   core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
+   core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
+   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
+   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
+   core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
+   core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
    core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
    core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
@@ -169,6 +183,54 @@ ignore_imports =
    core.workflow.nodes.llm.file_saver -> core.tools.signature
    core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
    core.workflow.nodes.tool.tool_node -> core.tools.errors
+   core.workflow.conversation_variable_updater -> core.variables
+   core.workflow.graph_engine.entities.commands -> core.variables.variables
+   core.workflow.nodes.agent.agent_node -> core.variables.segments
+   core.workflow.nodes.answer.answer_node -> core.variables
+   core.workflow.nodes.code.code_node -> core.variables.segments
+   core.workflow.nodes.code.code_node -> core.variables.types
+   core.workflow.nodes.code.entities -> core.variables.types
+   core.workflow.nodes.document_extractor.node -> core.variables
+   core.workflow.nodes.document_extractor.node -> core.variables.segments
+   core.workflow.nodes.http_request.executor -> core.variables.segments
+   core.workflow.nodes.http_request.node -> core.variables.segments
+   core.workflow.nodes.human_input.entities -> core.variables.consts
+   core.workflow.nodes.iteration.iteration_node -> core.variables
+   core.workflow.nodes.iteration.iteration_node -> core.variables.segments
+   core.workflow.nodes.iteration.iteration_node -> core.variables.variables
+   core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
+   core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
+   core.workflow.nodes.list_operator.node -> core.variables
+   core.workflow.nodes.list_operator.node -> core.variables.segments
+   core.workflow.nodes.llm.node -> core.variables
+   core.workflow.nodes.loop.loop_node -> core.variables
+   core.workflow.nodes.parameter_extractor.entities -> core.variables.types
+   core.workflow.nodes.parameter_extractor.exc -> core.variables.types
+   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
+   core.workflow.nodes.tool.tool_node -> core.variables.segments
+   core.workflow.nodes.tool.tool_node -> core.variables.variables
+   core.workflow.nodes.trigger_webhook.node -> core.variables.types
+   core.workflow.nodes.trigger_webhook.node -> core.variables.variables
+   core.workflow.nodes.variable_aggregator.entities -> core.variables.types
+   core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
+   core.workflow.nodes.variable_assigner.common.helpers -> core.variables
+   core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
+   core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
+   core.workflow.nodes.variable_assigner.v1.node -> core.variables
+   core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
+   core.workflow.nodes.variable_assigner.v2.node -> core.variables
+   core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
+   core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
+   core.workflow.runtime.read_only_wrappers -> core.variables.segments
+   core.workflow.runtime.variable_pool -> core.variables
+   core.workflow.runtime.variable_pool -> core.variables.consts
+   core.workflow.runtime.variable_pool -> core.variables.segments
+   core.workflow.runtime.variable_pool -> core.variables.variables
+   core.workflow.utils.condition.processor -> core.variables
+   core.workflow.utils.condition.processor -> core.variables.segments
+   core.workflow.variable_loader -> core.variables
+   core.workflow.variable_loader -> core.variables.consts
+   core.workflow.workflow_type_encoder -> core.variables
    core.workflow.nodes.agent.agent_node -> extensions.ext_database
    core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
    core.workflow.nodes.llm.file_saver -> extensions.ext_database
@@ -185,17 +247,17 @@ ignore_imports =
    core.workflow.workflow_entry -> models.enums
    core.workflow.nodes.agent.agent_node -> services
    core.workflow.nodes.tool.tool_node -> services
+   core.workflow.nodes.agent.agent_node -> core.model_runtime.token_buffer_memory
+   core.workflow.nodes.llm.llm_utils -> core.model_runtime.token_buffer_memory
+   core.workflow.nodes.llm.node -> core.model_runtime.token_buffer_memory
+   core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.token_buffer_memory
+   core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.token_buffer_memory
 
 [importlinter:contract:model-runtime-no-internal-imports]
 name = Model Runtime Internal Imports
 type = forbidden
 source_modules =
    core.model_runtime.callbacks
    core.model_runtime.entities
    core.model_runtime.errors
    core.model_runtime.model_providers
    core.model_runtime.schema_validators
    core.model_runtime.utils
+   core.model_runtime
 forbidden_modules =
    configs
    controllers
@@ -225,7 +287,7 @@ forbidden_modules =
    core.moderation
    core.ops
    core.plugin
-   core.model_runtime.prompt
+   core.prompt
    core.provider_manager
    core.rag
    core.repositories
@@ -242,6 +304,13 @@ ignore_imports =
    core.model_runtime.model_providers.model_provider_factory -> configs
    core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
    core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
+   core.model_runtime.token_buffer_memory -> core.app.app_config.features.file_upload.manager
+   core.model_runtime.token_buffer_memory -> core.model_manager
+   core.model_runtime.token_buffer_memory -> core.prompt.utils.extract_thread_messages
+   core.model_runtime.token_buffer_memory -> core.workflow.file.file_manager
+   core.model_runtime.token_buffer_memory -> extensions.ext_database
+   core.model_runtime.token_buffer_memory -> models.model
+   core.model_runtime.token_buffer_memory -> models.workflow
 
 [importlinter:contract:rsc]
 name = RSC
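These contracts follow import-linter's standard shape: a `forbidden` contract bans imports from `source_modules` into `forbidden_modules`, and each `ignore_imports` entry whitelists one known edge in `importer -> imported` form. A minimal config of the same shape, with placeholder module names rather than anything from this diff:

```ini
[importlinter]
root_package = mypackage

[importlinter:contract:example]
name = Example Forbidden Contract
type = forbidden
source_modules =
    mypackage.core
forbidden_modules =
    mypackage.web
ignore_imports =
    mypackage.core.legacy -> mypackage.web.helpers
```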
File diff suppressed because one or more lines are too long
@@ -15,11 +15,11 @@ from controllers.console.app.error import (
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from controllers.web.error import InvalidArgumentError, NotFoundError
+from core.variables.segment_group import SegmentGroup
+from core.variables.segments import ArrayFileSegment, FileSegment, Segment
+from core.variables.types import SegmentType
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from core.workflow.file import helpers as file_helpers
-from core.workflow.variables.segment_group import SegmentGroup
-from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
-from core.workflow.variables.types import SegmentType
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
@@ -21,8 +21,8 @@ from controllers.console.app.workflow_draft_variable import (
 from controllers.console.datasets.wraps import get_rag_pipeline
 from controllers.console.wraps import account_initialization_required, setup_required
 from controllers.web.error import InvalidArgumentError, NotFoundError
+from core.variables.types import SegmentType
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
-from core.workflow.variables.types import SegmentType
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
@@ -36,9 +36,9 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check account initialization
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.status == AccountStatus.UNINITIALIZED:
|
||||
@@ -214,9 +214,9 @@ def cloud_utm_record(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
def setup_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check setup
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
|
||||
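These wrappers rely on `ParamSpec` so the decorated view keeps its signature; dropping the explicit `-> Callable[P, R]` and `-> R` annotations leaves the return type to inference. A minimal, self-contained sketch of the pattern (not the Dify code itself):

```python
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def setup_required(view: Callable[P, R]) -> Callable[P, R]:
    """Decorator sketch: run a precondition check, then call the wrapped view unchanged."""

    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        # A real implementation would verify setup state here and raise on failure.
        return view(*args, **kwargs)

    return decorated
```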
@@ -17,7 +17,6 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     AssistantPromptMessage,

@@ -32,7 +31,8 @@ from core.model_runtime.entities import (
 from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
+from core.model_runtime.token_buffer_memory import TokenBufferMemory
+from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import (
     ToolParameter,
@@ -17,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.ops.ops_trace_manager import TraceQueueManager
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
@@ -21,7 +21,7 @@ from core.model_runtime.entities import (
     UserPromptMessage,
 )
 from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
-from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
 from core.workflow.file import file_manager
@@ -5,7 +5,7 @@ from core.app.app_config.entities import (
     PromptTemplateEntity,
 )
 from core.model_runtime.entities.message_entities import PromptMessageRole
-from core.model_runtime.prompt.simple_prompt_transform import ModelMode
+from core.prompt.simple_prompt_transform import ModelMode
 from models.model import AppMode
@@ -32,8 +32,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
 from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
 from core.helper.trace_id_helper import extract_external_trace_id_from_args
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
-from core.model_runtime.prompt.utils.get_thread_messages_length import get_thread_messages_length
 from core.ops.ops_trace_manager import TraceQueueManager
+from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
 from core.repositories import DifyCoreRepositoryFactory
 from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.repositories.draft_variable_repository import (
@@ -25,6 +25,7 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl
 from core.db.session_factory import session_factory
 from core.moderation.base import ModerationError
 from core.moderation.input_moderation import InputModeration
+from core.variables.variables import Variable
 from core.workflow.enums import WorkflowType
 from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
 from core.workflow.graph_engine.layers.base import GraphEngineLayer

@@ -33,7 +34,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
-from core.workflow.variables.variables import Variable
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -12,11 +12,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.base_app_runner import AppRunner
 from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.token_buffer_memory import TokenBufferMemory
 from core.moderation.base import ModerationError
 from extensions.ext_database import db
 from models.model import App, Conversation, Message
@@ -22,7 +22,6 @@ from core.app.entities.queue_entities import (
 from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
 from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
 from core.external_data_tool.external_data_fetch import ExternalDataFetch
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (

@@ -33,14 +32,11 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
-from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
-from core.model_runtime.prompt.entities.advanced_prompt_entities import (
-    ChatModelMessage,
-    CompletionModelPromptTemplate,
-    MemoryConfig,
-)
-from core.model_runtime.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
+from core.model_runtime.token_buffer_memory import TokenBufferMemory
 from core.moderation.input_moderation import InputModeration
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
+from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.file.enums import FileTransferMethod, FileType
 from extensions.ext_database import db
@@ -11,9 +11,9 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.token_buffer_memory import TokenBufferMemory
 from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.workflow.file import File
@@ -49,6 +49,7 @@ from core.plugin.impl.datasource import PluginDatasourceManager
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.tool_manager import ToolManager
 from core.trigger.trigger_manager import TriggerManager
+from core.variables.segments import ArrayFileSegment, FileSegment, Segment
 from core.workflow.entities.pause_reason import HumanInputRequired
 from core.workflow.entities.workflow_start_reason import WorkflowStartReason
 from core.workflow.enums import (

@@ -61,7 +62,6 @@ from core.workflow.enums import (
 from core.workflow.file import FILE_MODEL_IDENTITY, File
 from core.workflow.runtime import GraphRuntimeState
 from core.workflow.system_variable import SystemVariable
-from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
 from core.workflow.workflow_entry import WorkflowEntry
 from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from extensions.ext_database import db
@@ -27,7 +27,7 @@ from core.app.entities.task_entities import (
     CompletionAppStreamResponse,
 )
 from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
-from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
 from extensions.ext_redis import get_pubsub_broadcast_channel
 from libs.broadcast_channel.channel import Topic
@@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.app.workflow.node_factory import DifyNodeFactory
+from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
 from core.workflow.entities.graph_init_params import GraphInitParams
 from core.workflow.enums import WorkflowType
 from core.workflow.graph import Graph

@@ -20,7 +21,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
-from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.dataset import Document, Pipeline
@@ -1,12 +1,12 @@
 import logging
 
+from core.variables import VariableBase
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
 from core.workflow.conversation_variable_updater import ConversationVariableUpdater
 from core.workflow.enums import NodeType
 from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
 from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
-from core.workflow.variables import VariableBase
 
 logger = logging.getLogger(__name__)
@@ -83,21 +83,14 @@ def fetch_model_config(
         raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
     provider_model.raise_for_status()
 
-    completion_params = dict(node_data_model.completion_params)
-    stop = completion_params.pop("stop", [])
-    if not isinstance(stop, list):
-        stop = []
+    stop: list[str] = []
+    if "stop" in node_data_model.completion_params:
+        stop = node_data_model.completion_params.pop("stop")
 
     model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
     if not model_schema:
         raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
 
-    model_instance.provider = node_data_model.provider
-    model_instance.model_name = node_data_model.name
-    model_instance.credentials = credentials
-    model_instance.parameters = completion_params
-    model_instance.stop = tuple(stop)
-
     return model_instance, ModelConfigWithCredentialsEntity(
         provider=node_data_model.provider,
         model=node_data_model.name,

@@ -105,6 +98,6 @@ def fetch_model_config(
         mode=node_data_model.mode,
         provider_model_bundle=provider_model_bundle,
         credentials=credentials,
-        parameters=completion_params,
+        parameters=node_data_model.completion_params,
         stop=stop,
     )
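The restored `stop` handling pops the stop words out of `completion_params` before the remaining parameters are passed through. The same behavior in isolation (a sketch with a plain dict, not the Dify entities):

```python
def split_stop_words(completion_params: dict) -> tuple[dict, list[str]]:
    """Separate the optional 'stop' list from the remaining completion params (mutates the dict)."""
    stop: list[str] = []
    if "stop" in completion_params:
        stop = completion_params.pop("stop")
    return completion_params, stop


params, stop = split_stop_words({"temperature": 0.7, "stop": ["\n\n"]})
assert params == {"temperature": 0.7} and stop == ["\n\n"]
```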
@@ -52,10 +52,10 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
+from core.prompt.utils.prompt_message_util import PromptMessageUtil
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.tools.signature import sign_tool_file
 from core.workflow.file import helpers as file_helpers
 from core.workflow.file.enums import FileTransferMethod
@@ -1,20 +1,14 @@
 from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any, cast, final
+from typing import TYPE_CHECKING, Any, final
 
 from typing_extensions import override
 
 from configs import dify_config
+from core.app.llm.model_access import build_dify_model_access
 from core.datasource.datasource_manager import DatasourceManager
-from core.helper.code_executor.code_executor import (
-    CodeExecutionError,
-    CodeExecutor,
-)
+from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
 from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.ssrf_proxy import ssrf_proxy
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.entities.graph_config import NodeConfigDict

@@ -29,11 +23,7 @@ from core.workflow.nodes.datasource import DatasourceNode
 from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
 from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
-from core.workflow.nodes.llm import llm_utils
-from core.workflow.nodes.llm.entities import ModelConfig
-from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
 from core.workflow.nodes.llm.node import LLMNode
-from core.workflow.nodes.llm.protocols import PromptMessageMemory
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode

@@ -82,6 +72,7 @@ class DifyNodeFactory(NodeFactory):
         self.graph_init_params = graph_init_params
         self.graph_runtime_state = graph_runtime_state
         self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
+        self._code_providers: tuple[type[CodeNodeProvider], ...] = CodeNode.default_code_providers()
         self._code_limits = CodeNodeLimits(
             max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
             max_number=dify_config.CODE_MAX_NUMBER,

@@ -153,6 +144,7 @@ class DifyNodeFactory(NodeFactory):
             graph_init_params=self.graph_init_params,
             graph_runtime_state=self.graph_runtime_state,
             code_executor=self._code_executor,
+            code_providers=self._code_providers,
             code_limits=self._code_limits,
         )
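The factory now resolves code providers once in its constructor and threads the same tuple into every `CodeNode` it builds. A rough sketch of that pattern (names simplified; `default_code_providers` is the hook visible in the diff, everything else here is illustrative):

```python
from typing import ClassVar


class CodeNodeProvider:
    """Base class sketch: one provider per supported code language."""

    language: ClassVar[str]


class Python3CodeProvider(CodeNodeProvider):
    language = "python3"


class JavascriptCodeProvider(CodeNodeProvider):
    language = "javascript"


def default_code_providers() -> tuple[type[CodeNodeProvider], ...]:
    # Resolved once by the factory and passed to each CodeNode it constructs,
    # so individual nodes never import concrete providers themselves.
    return (Python3CodeProvider, JavascriptCodeProvider)
```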
@@ -179,8 +171,6 @@ class DifyNodeFactory(NodeFactory):
             )
 
         if node_type == NodeType.LLM:
-            model_instance = self._build_model_instance_for_llm_node(node_data)
-            memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
             return LLMNode(
                 id=node_id,
                 config=node_config,

@@ -188,8 +178,6 @@ class DifyNodeFactory(NodeFactory):
                 graph_runtime_state=self.graph_runtime_state,
                 credentials_provider=self._llm_credentials_provider,
                 model_factory=self._llm_model_factory,
-                model_instance=model_instance,
-                memory=memory,
             )
 
         if node_type == NodeType.DATASOURCE:

@@ -220,7 +208,6 @@ class DifyNodeFactory(NodeFactory):
             )
 
         if node_type == NodeType.QUESTION_CLASSIFIER:
-            model_instance = self._build_model_instance_for_llm_node(node_data)
             return QuestionClassifierNode(
                 id=node_id,
                 config=node_config,

@@ -228,11 +215,9 @@ class DifyNodeFactory(NodeFactory):
                 graph_runtime_state=self.graph_runtime_state,
                 credentials_provider=self._llm_credentials_provider,
                 model_factory=self._llm_model_factory,
-                model_instance=model_instance,
             )
 
         if node_type == NodeType.PARAMETER_EXTRACTOR:
-            model_instance = self._build_model_instance_for_llm_node(node_data)
             return ParameterExtractorNode(
                 id=node_id,
                 config=node_config,

@@ -240,7 +225,6 @@ class DifyNodeFactory(NodeFactory):
                 graph_runtime_state=self.graph_runtime_state,
                 credentials_provider=self._llm_credentials_provider,
                 model_factory=self._llm_model_factory,
-                model_instance=model_instance,
             )
 
         return node_class(
@@ -249,55 +233,3 @@ class DifyNodeFactory(NodeFactory):
             graph_init_params=self.graph_init_params,
             graph_runtime_state=self.graph_runtime_state,
         )
-
-    def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
-        node_data_model = ModelConfig.model_validate(node_data["model"])
-        if not node_data_model.mode:
-            raise LLMModeRequiredError("LLM mode is required.")
-
-        credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name)
-        model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
-        provider_model_bundle = model_instance.provider_model_bundle
-
-        provider_model = provider_model_bundle.configuration.get_provider_model(
-            model=node_data_model.name,
-            model_type=ModelType.LLM,
-        )
-        if provider_model is None:
-            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-        provider_model.raise_for_status()
-
-        completion_params = dict(node_data_model.completion_params)
-        stop = completion_params.pop("stop", [])
-        if not isinstance(stop, list):
-            stop = []
-
-        model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
-        if not model_schema:
-            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-
-        model_instance.provider = node_data_model.provider
-        model_instance.model_name = node_data_model.name
-        model_instance.credentials = credentials
-        model_instance.parameters = completion_params
-        model_instance.stop = tuple(stop)
-        model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
-        return model_instance
-
-    def _build_memory_for_llm_node(
-        self,
-        *,
-        node_data: Mapping[str, Any],
-        model_instance: ModelInstance,
-    ) -> PromptMessageMemory | None:
-        raw_memory_config = node_data.get("memory")
-        if raw_memory_config is None:
-            return None
-
-        node_memory = MemoryConfig.model_validate(raw_memory_config)
-        return llm_utils.fetch_memory(
-            variable_pool=self.graph_runtime_state.variable_pool,
-            app_id=self.graph_init_params.app_id,
-            node_data_memory=node_memory,
-            model_instance=model_instance,
-        )
@@ -1,5 +1,6 @@
 import logging
 from collections.abc import Mapping
+from enum import StrEnum
 from threading import Lock
 from typing import Any

@@ -13,7 +14,6 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr
 from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
 from core.helper.code_executor.template_transformer import TemplateTransformer
 from core.helper.http_client_pooling import get_pooled_http_client
-from core.workflow.nodes.code.entities import CodeLanguage
 
 logger = logging.getLogger(__name__)
 code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))

@@ -40,6 +40,12 @@ class CodeExecutionResponse(BaseModel):
     data: Data
 
 
+class CodeLanguage(StrEnum):
+    PYTHON3 = "python3"
+    JINJA2 = "jinja2"
+    JAVASCRIPT = "javascript"
+
+
 def _build_code_executor_client() -> httpx.Client:
     return httpx.Client(
         verify=CODE_EXECUTION_SSL_VERIFY,
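Moving `CodeLanguage` into the executor module and basing it on `StrEnum` keeps the enum usable anywhere plain strings are expected, which is what lets the old import from `core.workflow.nodes.code.entities` be dropped. A quick illustration of the `StrEnum` behavior (standard library, Python 3.11+):

```python
from enum import StrEnum


class CodeLanguage(StrEnum):
    PYTHON3 = "python3"
    JINJA2 = "jinja2"
    JAVASCRIPT = "javascript"


# StrEnum members compare equal to their string values, so request payloads
# and config files can carry plain strings while code keeps the enum.
assert CodeLanguage.PYTHON3 == "python3"
assert CodeLanguage("javascript") is CodeLanguage.JAVASCRIPT
```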
@@ -5,7 +5,7 @@ from base64 import b64encode
 from collections.abc import Mapping
 from typing import Any
 
-from core.workflow.variables.utils import dumps_with_segments
+from core.variables.utils import dumps_with_segments
 
 
 class TemplateTransformer(ABC):
@@ -27,10 +27,10 @@ from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
-from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
 from extensions.ext_database import db
 from extensions.ext_storage import storage
@@ -1,5 +1,5 @@
 import logging
-from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
+from collections.abc import Callable, Generator, Iterable, Sequence
 from typing import IO, Any, Literal, Optional, Union, cast, overload
 
 from configs import dify_config

@@ -38,9 +38,6 @@ class ModelInstance:
         self.model_name = model
         self.provider = provider_model_bundle.configuration.provider.provider
         self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
-        # Runtime LLM invocation fields.
-        self.parameters: Mapping[str, Any] = {}
-        self.stop: Sequence[str] = ()
         self.model_type_instance = self.provider_model_bundle.model_type_instance
         self.load_balancing_manager = self._get_load_balancing_manager(
             configuration=provider_model_bundle.configuration,
@@ -83,21 +83,19 @@ def _merge_tool_call_delta(
             tool_call.function.arguments += delta.function.arguments
 
 
-def _build_llm_result_from_chunks(
+def _build_llm_result_from_first_chunk(
     model: str,
     prompt_messages: Sequence[PromptMessage],
     chunks: Iterator[LLMResultChunk],
 ) -> LLMResult:
     """
-    Build a single `LLMResult` by accumulating all returned chunks.
+    Build a single `LLMResult` from the first returned chunk.
 
-    Some models only support streaming output (e.g. Qwen3 open-source edition)
-    and the plugin side may still implement the response via a chunked stream,
-    so all chunks must be consumed and concatenated into a single ``LLMResult``.
+    This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
 
-    The ``usage`` is taken from the last chunk that carries it, which is the
-    typical convention for streaming responses (the final chunk contains the
-    aggregated token counts).
+    Note:
+        This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
+        streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
     """
     content = ""
     content_list: list[PromptMessageContentUnionTypes] = []

@@ -106,27 +104,24 @@
     tools_calls: list[AssistantPromptMessage.ToolCall] = []
 
     try:
-        for chunk in chunks:
-            if isinstance(chunk.delta.message.content, str):
-                content += chunk.delta.message.content
-            elif isinstance(chunk.delta.message.content, list):
-                content_list.extend(chunk.delta.message.content)
+        first_chunk = next(chunks, None)
+        if first_chunk is not None:
+            if isinstance(first_chunk.delta.message.content, str):
+                content += first_chunk.delta.message.content
+            elif isinstance(first_chunk.delta.message.content, list):
+                content_list.extend(first_chunk.delta.message.content)
 
-            if chunk.delta.message.tool_calls:
-                _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
+            if first_chunk.delta.message.tool_calls:
+                _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
 
-            if chunk.delta.usage:
-                usage = chunk.delta.usage
-            if chunk.system_fingerprint:
-                system_fingerprint = chunk.system_fingerprint
-    except Exception:
-        logger.exception("Error while consuming non-stream plugin chunk iterator.")
-        raise
+            usage = first_chunk.delta.usage or LLMUsage.empty_usage()
+            system_fingerprint = first_chunk.system_fingerprint
     finally:
-        # Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections).
-        close = getattr(chunks, "close", None)
-        if callable(close):
-            close()
+        try:
+            for _ in chunks:
+                pass
+        except Exception:
+            logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)
 
     return LLMResult(
         model=model,

@@ -179,7 +174,7 @@ def _normalize_non_stream_plugin_result(
 ) -> LLMResult:
     if isinstance(result, LLMResult):
         return result
-    return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result)
+    return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result)
 
 
 def _increase_tool_call(
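The new helper reads one chunk and then guarantees the iterator is exhausted in `finally`, so a streaming plugin's HTTP connection is not left dangling. The shape of that pattern in isolation (a generic sketch, not the Dify types):

```python
from collections.abc import Iterator


def first_item_then_drain(items: Iterator[int]) -> int | None:
    """Return the first item, but always consume the rest of the iterator."""
    try:
        return next(items, None)
    finally:
        # Draining lets generator-based producers run their cleanup code
        # (e.g. closing an HTTP response) even though only one item was needed.
        for _ in items:
            pass
```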
@@ -14,7 +14,7 @@ from core.model_runtime.entities import (
     UserPromptMessage,
 )
 from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
-from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
+from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from core.workflow.file import file_manager
 from extensions.ext_database import db
 from factories import file_factory
@@ -14,9 +14,10 @@ class BaseTraceInstance(ABC):
     Base trace instance for ops trace services
     """
 
+    @abstractmethod
     def __init__(self, trace_config: BaseTracingConfig):
         """
-        Initializer for the trace instance.
+        Abstract initializer for the trace instance.
         Distribute trace tasks by matching entities
         """
         self.trace_config = trace_config
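Marking `__init__` with `@abstractmethod` keeps `BaseTraceInstance` uninstantiable while subclasses can still reuse its body via `super().__init__`. A minimal sketch of the idiom (config typed as `str` here purely as a stand-in):

```python
from abc import ABC, abstractmethod


class BaseTraceInstance(ABC):
    @abstractmethod
    def __init__(self, trace_config: str):
        self.trace_config = trace_config


class LangfuseTrace(BaseTraceInstance):
    def __init__(self, trace_config: str):
        super().__init__(trace_config)


# BaseTraceInstance("x") would raise TypeError; the concrete subclass works fine.
trace = LangfuseTrace("config")
```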
@@ -41,8 +41,8 @@ logger = logging.getLogger(__name__)
 
 
 class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
-    def __getitem__(self, key: str) -> dict[str, Any]:
-        match key:
+    def __getitem__(self, provider: str) -> dict[str, Any]:
+        match provider:
             case TracingProviderEnum.LANGFUSE:
                 from core.ops.entities.config_entity import LangfuseConfig
                 from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace

@@ -149,7 +149,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
             }
 
             case _:
-                raise KeyError(f"Unsupported tracing provider: {key}")
+                raise KeyError(f"Unsupported tracing provider: {provider}")
 
 
 provider_config_map = OpsTraceProviderConfigMap()
@@ -3,8 +3,6 @@ from typing import cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,

@@ -14,13 +12,10 @@ from core.model_runtime.entities import (
     UserPromptMessage,
 )
 from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
-from core.model_runtime.prompt.entities.advanced_prompt_entities import (
-    ChatModelMessage,
-    CompletionModelPromptTemplate,
-    MemoryConfig,
-)
-from core.model_runtime.prompt.prompt_transform import PromptTransform
-from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.model_runtime.token_buffer_memory import TokenBufferMemory
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
+from core.prompt.prompt_transform import PromptTransform
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.file import file_manager
 from core.workflow.file.models import File
 from core.workflow.runtime import VariablePool
@@ -49,8 +44,7 @@ class AdvancedPromptTransform(PromptTransform):
         context: str | None,
         memory_config: MemoryConfig | None,
         memory: TokenBufferMemory | None,
-        model_config: ModelConfigWithCredentialsEntity | None = None,
-        model_instance: ModelInstance | None = None,
+        model_config: ModelConfigWithCredentialsEntity,
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
         prompt_messages = []

@@ -65,7 +59,6 @@ class AdvancedPromptTransform(PromptTransform):
                 memory_config=memory_config,
                 memory=memory,
                 model_config=model_config,
-                model_instance=model_instance,
                 image_detail_config=image_detail_config,
             )
         elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):

@@ -78,7 +71,6 @@ class AdvancedPromptTransform(PromptTransform):
                 memory_config=memory_config,
                 memory=memory,
                 model_config=model_config,
-                model_instance=model_instance,
                 image_detail_config=image_detail_config,
             )
 
@@ -93,8 +85,7 @@ class AdvancedPromptTransform(PromptTransform):
         context: str | None,
         memory_config: MemoryConfig | None,
         memory: TokenBufferMemory | None,
-        model_config: ModelConfigWithCredentialsEntity | None = None,
-        model_instance: ModelInstance | None = None,
+        model_config: ModelConfigWithCredentialsEntity,
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
         """

@@ -120,7 +111,6 @@ class AdvancedPromptTransform(PromptTransform):
                 parser=parser,
                 prompt_inputs=prompt_inputs,
                 model_config=model_config,
-                model_instance=model_instance,
             )
 
         if query:

@@ -156,8 +146,7 @@ class AdvancedPromptTransform(PromptTransform):
         context: str | None,
         memory_config: MemoryConfig | None,
         memory: TokenBufferMemory | None,
-        model_config: ModelConfigWithCredentialsEntity | None = None,
-        model_instance: ModelInstance | None = None,
+        model_config: ModelConfigWithCredentialsEntity,
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
         """

@@ -209,13 +198,8 @@ class AdvancedPromptTransform(PromptTransform):
 
         prompt_message_contents: list[PromptMessageContentUnionTypes] = []
         if memory and memory_config:
-            prompt_messages = self._append_chat_histories(
-                memory,
-                memory_config,
-                prompt_messages,
-                model_config=model_config,
-                model_instance=model_instance,
-            )
+            prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
 
         if files and query is not None:
             for file in files:
                 prompt_message_contents.append(

@@ -292,8 +276,7 @@ class AdvancedPromptTransform(PromptTransform):
         role_prefix: MemoryConfig.RolePrefix,
         parser: PromptTemplateParser,
         prompt_inputs: Mapping[str, str],
-        model_config: ModelConfigWithCredentialsEntity | None = None,
-        model_instance: ModelInstance | None = None,
+        model_config: ModelConfigWithCredentialsEntity,
     ) -> Mapping[str, str]:
         prompt_inputs = dict(prompt_inputs)
         if "#histories#" in parser.variable_keys:

@@ -303,11 +286,7 @@ class AdvancedPromptTransform(PromptTransform):
             prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
         tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
 
-        rest_tokens = self._calculate_rest_token(
-            [tmp_human_message],
-            model_config=model_config,
-            model_instance=model_instance,
-        )
+        rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
 
         histories = self._get_history_messages_from_memory(
             memory=memory,
@@ -3,14 +3,14 @@ from typing import cast
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
ModelConfigWithCredentialsEntity,
|
||||
)
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.prompt.prompt_transform import PromptTransform
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
|
||||
|
||||
class AgentHistoryPromptTransform(PromptTransform):
|
||||
@@ -41,7 +41,7 @@ class AgentHistoryPromptTransform(PromptTransform):
|
||||
if not self.memory:
|
||||
return prompt_messages
|
||||
|
||||
max_token_limit = self._calculate_rest_token(self.prompt_messages, model_config=self.model_config)
|
||||
max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
|
||||
|
||||
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
@@ -1,86 +1,48 @@
from typing import Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig


class PromptTransform:
    def _resolve_model_runtime(
        self,
        *,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
    ) -> tuple[ModelInstance, AIModelEntity]:
        if model_instance is None:
            if model_config is None:
                raise ValueError("Either model_config or model_instance must be provided.")
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )
            model_instance.credentials = model_config.credentials
            model_instance.parameters = model_config.parameters
            model_instance.stop = model_config.stop

        model_schema = model_instance.model_type_instance.get_model_schema(
            model=model_instance.model_name,
            credentials=model_instance.credentials,
        )
        if model_schema is None:
            if model_config is None:
                raise ValueError("Model schema not found for the provided model instance.")
            model_schema = model_config.model_schema

        return model_instance, model_schema

    def _append_chat_histories(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        prompt_messages: list[PromptMessage],
        *,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
        model_config: ModelConfigWithCredentialsEntity,
    ) -> list[PromptMessage]:
        rest_tokens = self._calculate_rest_token(
            prompt_messages,
            model_config=model_config,
            model_instance=model_instance,
        )
        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
        prompt_messages.extend(histories)

        return prompt_messages

    def _calculate_rest_token(
        self,
        prompt_messages: list[PromptMessage],
        *,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
    ) -> int:
        model_instance, model_schema = self._resolve_model_runtime(
            model_config=model_config,
            model_instance=model_instance,
        )
        model_parameters = model_instance.parameters
        rest_tokens = 2000

        model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )

            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

            max_tokens = 0
            for parameter_rule in model_schema.parameter_rules:
            for parameter_rule in model_config.model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_parameters.get(parameter_rule.name)
                        or model_parameters.get(parameter_rule.use_template or "")
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(parameter_rule.use_template or "")
                    ) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
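The token budget above reduces to: rest_tokens = context window − reserved completion tokens (`max_tokens`) − tokens already consumed by the current prompt, with 2000 kept as a fallback when the schema exposes no context size. A worked example with assumed numbers:

# Illustrative values only; real numbers come from the model schema and tokenizer.
model_context_tokens = 8192   # ModelPropertyKey.CONTEXT_SIZE
max_tokens = 1024             # completion budget from model parameters
curr_message_tokens = 3100    # get_llm_num_tokens(prompt_messages)

rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
assert rest_tokens == 4068    # budget left for chat history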
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, cast

from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    PromptMessage,
@@ -15,9 +14,10 @@ from core.model_runtime.entities.message_entities import (
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.prompt.prompt_transform import PromptTransform
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import file_manager
from models.model import AppMode

@@ -252,7 +252,7 @@ class SimplePromptTransform(PromptTransform):
        if memory:
            tmp_human_message = UserPromptMessage(content=prompt)

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config=model_config)
            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
@@ -1,6 +1,6 @@
from sqlalchemy import select

from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message

@@ -10,7 +10,7 @@ from core.model_runtime.entities import (
    PromptMessageRole,
    TextPromptMessageContent,
)
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.prompt.simple_prompt_transform import ModelMode


class PromptMessageUtil:
@@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            vector=None,  # ty: ignore [invalid-argument-type]
            content=None,  # ty: ignore [invalid-argument-type]
            top_k=1,
            filter=f"ref_doc_id='{id}'",
        )
@@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data=None,  # ty: ignore [invalid-argument-type]
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)
@@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data=None,  # ty: ignore [invalid-argument-type]
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)
@@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            content=None,  # ty: ignore [invalid-argument-type]
            top_k=kwargs.get("top_k", 4),
            filter=where_clause,
        )
@@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            vector=None,  # ty: ignore [invalid-argument-type]
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=where_clause,

@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        try:
            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))  # ty: ignore [too-many-positional-arguments]
            search_iter = self._scope.search(
                self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
            )

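The new `# ty: ignore [...]` comments are rule-scoped suppressions for the `ty` checker wired into the `type-check` target: each one silences a single named diagnostic on its own line instead of turning checking off for the whole file. A minimal sketch of the pattern, with the rule name in brackets matching the diagnostic being suppressed:

def takes_str(value: str) -> None:
    ...

# Suppress only the invalid-argument-type diagnostic on this one call site.
takes_str(None)  # ty: ignore [invalid-argument-type]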
@@ -75,15 +75,15 @@ class BaseIndexProcessor(ABC):
        multimodal_documents: list[AttachmentDocument] | None = None,
        with_keywords: bool = True,
        **kwargs,
    ) -> None:
    ):
        raise NotImplementedError

    @abstractmethod
    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
        raise NotImplementedError

    @abstractmethod

@@ -115,7 +115,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
        multimodal_documents: list[AttachmentDocument] | None = None,
        with_keywords: bool = True,
        **kwargs,
    ) -> None:
    ):
        if dataset.indexing_technique == "high_quality":
            vector = Vector(dataset)
            vector.create(documents)
@@ -130,7 +130,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
            else:
                keyword.add_texts(documents)

    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
        # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
        # This method is called for actual deletion scenarios (e.g., when segment is deleted).
        # For disable operations, disable_summaries_for_segments is called directly in the task.
@@ -196,7 +196,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
            docs.append(doc)
        return docs

    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
        documents: list[Any] = []
        all_multimodal_documents: list[Any] = []
        if isinstance(chunks, list):
@@ -469,7 +469,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
        if not isinstance(result, LLMResult):
            raise ValueError("Expected LLMResult when stream=False")

        summary_content = result.message.get_text_content()
        summary_content = getattr(result.message, "content", "")
        usage = result.usage

        # Deduct quota for summary generation (same as workflow nodes)

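The summary-extraction pair above differs in how it tolerates message shape: an accessor such as `get_text_content()` can flatten content that is a list of parts, while `getattr(result.message, "content", "")` only reads the raw attribute. A hedged sketch of what such an accessor might do (a simplification for illustration, not the repo's implementation):

def get_text_content(content: str | list[object]) -> str:
    # Assistant message content may be a plain string or a list of parts;
    # keep the textual pieces and join them.
    if isinstance(content, str):
        return content
    return "".join(getattr(part, "data", "") for part in content)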
@@ -126,7 +126,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
        multimodal_documents: list[AttachmentDocument] | None = None,
        with_keywords: bool = True,
        **kwargs,
    ) -> None:
    ):
        if dataset.indexing_technique == "high_quality":
            vector = Vector(dataset)
            for document in documents:
@@ -139,7 +139,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
            if multimodal_documents and dataset.is_multimodal:
                vector.create_multimodal(multimodal_documents)

    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
        # node_ids is segment's node_ids
        # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
        # This method is called for actual deletion scenarios (e.g., when segment is deleted).
@@ -272,7 +272,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
            child_nodes.append(child_document)
        return child_nodes

    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
        parent_childs = ParentChildStructureChunk.model_validate(chunks)
        documents = []
        for parent_child in parent_childs.parent_child_chunks:

@@ -139,14 +139,14 @@ class QAIndexProcessor(BaseIndexProcessor):
        multimodal_documents: list[AttachmentDocument] | None = None,
        with_keywords: bool = True,
        **kwargs,
    ) -> None:
    ):
        if dataset.indexing_technique == "high_quality":
            vector = Vector(dataset)
            vector.create(documents)
            if multimodal_documents and dataset.is_multimodal:
                vector.create_multimodal(multimodal_documents)

    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
        # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
        # This method is called for actual deletion scenarios (e.g., when segment is deleted).
        # For disable operations, disable_summaries_for_segments is called directly in the task.
@@ -206,7 +206,7 @@ class QAIndexProcessor(BaseIndexProcessor):
            docs.append(doc)
        return docs

    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
        qa_chunks = QAStructureChunk.model_validate(chunks)
        documents = []
        for qa_chunk in qa_chunks.qa_chunks:

@@ -23,18 +23,18 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.db.session_factory import session_factory
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService

@@ -5,8 +5,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import llm_utils

@@ -4,6 +4,8 @@ from uuid import uuid4

from pydantic import BaseModel, Discriminator, Field, Tag

from core.helper import encrypter

from .segments import (
    ArrayAnySegment,
    ArrayBooleanSegment,
@@ -25,14 +27,6 @@ from .segments import (
from .types import SegmentType


def _obfuscated_token(token: str) -> str:
    if not token:
        return token
    if len(token) <= 8:
        return "*" * 20
    return token[:6] + "*" * 12 + token[-2:]


class VariableBase(Segment):
    """
    A variable is a segment that has a name.
@@ -92,7 +86,7 @@ class SecretVariable(StringVariable):

    @property
    def log(self) -> str:
        return _obfuscated_token(self.value)
        return encrypter.obfuscated_token(self.value)


class NoneVariable(NoneSegment, VariableBase):
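The local `_obfuscated_token` helper shown above is removed in favor of the shared `core.helper.encrypter.obfuscated_token`. Assuming the shared helper preserves the removed function's behavior, the masking rule works out to:

# Behavior inferred from the removed helper above; illustrative asserts.
assert _obfuscated_token("") == ""                          # empty stays empty
assert _obfuscated_token("short") == "*" * 20               # <= 8 chars: fixed-width full mask
assert _obfuscated_token("sk-abcdef123456789012") == "sk-abc" + "*" * 12 + "12"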
@@ -1,7 +1,7 @@
import abc
from typing import Protocol

from core.workflow.variables import VariableBase
from core.variables import VariableBase


class ConversationVariableUpdater(Protocol):

@@ -11,7 +11,7 @@ from typing import Any

from pydantic import BaseModel, Field

from core.workflow.variables.variables import Variable
from core.variables.variables import Variable


class CommandType(StrEnum):

@@ -9,6 +9,7 @@ from __future__ import annotations

import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final

@@ -76,10 +77,13 @@ class GraphEngine:
        config: GraphEngineConfig = _DEFAULT_CONFIG,
    ) -> None:
        """Initialize the graph engine with all subsystems and dependencies."""
        # stop event
        self._stop_event = threading.Event()

        # Bind runtime state to current workflow context
        self._graph = graph
        self._graph_runtime_state = graph_runtime_state
        self._graph_runtime_state.stop_event = self._stop_event
        self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
        self._command_channel = command_channel
        self._config = config
@@ -159,6 +163,7 @@ class GraphEngine:
            layers=self._layers,
            execution_context=execution_context,
            config=self._config,
            stop_event=self._stop_event,
        )

        # === Orchestration ===
@@ -189,6 +194,7 @@ class GraphEngine:
            event_handler=self._event_handler_registry,
            execution_coordinator=self._execution_coordinator,
            event_emitter=self._event_manager,
            stop_event=self._stop_event,
        )

        # === Validation ===
@@ -308,6 +314,7 @@ class GraphEngine:

    def _start_execution(self, *, resume: bool = False) -> None:
        """Start execution subsystems."""
        self._stop_event.clear()
        paused_nodes: list[str] = []
        deferred_nodes: list[str] = []
        if resume:
@@ -341,6 +348,7 @@ class GraphEngine:

    def _stop_execution(self) -> None:
        """Stop execution subsystems."""
        self._stop_event.set()
        self._dispatcher.stop()
        self._worker_pool.stop()
        # Don't mark complete here as the dispatcher already does it

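The pattern running through this and the next three hunks: the engine now owns a single `threading.Event` and injects it into the dispatcher, the worker pool, and every worker, so one `set()` cancels all of them cooperatively instead of each subsystem flipping a private event. A minimal sketch of the wiring (hypothetical classes, not the engine's real API):

import threading

class Subsystem:
    def __init__(self, stop_event: threading.Event) -> None:
        self._stop_event = stop_event  # shared, injected; never created here

    def loop(self) -> None:
        while not self._stop_event.is_set():
            pass  # do one unit of work

class Engine:
    def __init__(self) -> None:
        self._stop_event = threading.Event()
        self._subsystems = [Subsystem(self._stop_event) for _ in range(3)]

    def stop(self) -> None:
        self._stop_event.set()  # one signal reaches every subsystem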
@@ -44,6 +44,7 @@ class Dispatcher:
        event_queue: queue.Queue[GraphNodeEventBase],
        event_handler: "EventHandler",
        execution_coordinator: ExecutionCoordinator,
        stop_event: threading.Event,
        event_emitter: EventManager | None = None,
    ) -> None:
        """
@@ -61,7 +62,7 @@ class Dispatcher:
        self._event_emitter = event_emitter

        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        self._stop_event = stop_event
        self._start_time: float | None = None

    def start(self) -> None:
@@ -69,14 +70,12 @@ class Dispatcher:
        if self._thread and self._thread.is_alive():
            return

        self._stop_event.clear()
        self._start_time = time.time()
        self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Stop the dispatcher thread."""
        self._stop_event.set()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=2.0)


@@ -42,6 +42,7 @@ class Worker(threading.Thread):
        event_queue: queue.Queue[GraphNodeEventBase],
        graph: Graph,
        layers: Sequence[GraphEngineLayer],
        stop_event: threading.Event,
        worker_id: int = 0,
        execution_context: IExecutionContext | None = None,
    ) -> None:
@@ -62,13 +63,16 @@ class Worker(threading.Thread):
        self._graph = graph
        self._worker_id = worker_id
        self._execution_context = execution_context
        self._stop_event = threading.Event()
        self._stop_event = stop_event
        self._layers = layers if layers is not None else []
        self._last_task_time = time.time()

    def stop(self) -> None:
        """Signal the worker to stop processing."""
        self._stop_event.set()
        """Worker is controlled via shared stop_event from GraphEngine.

        This method is a no-op retained for backward compatibility.
        """
        pass

    @property
    def is_idle(self) -> bool:

@@ -37,6 +37,7 @@ class WorkerPool:
        event_queue: queue.Queue[GraphNodeEventBase],
        graph: Graph,
        layers: list[GraphEngineLayer],
        stop_event: threading.Event,
        config: GraphEngineConfig,
        execution_context: IExecutionContext | None = None,
    ) -> None:
@@ -63,6 +64,7 @@ class WorkerPool:
        self._worker_counter = 0
        self._lock = threading.RLock()
        self._running = False
        self._stop_event = stop_event

        # No longer tracking worker states with callbacks to avoid lock contention

@@ -133,6 +135,7 @@ class WorkerPool:
            layers=self._layers,
            worker_id=worker_id,
            execution_context=self._execution_context,
            stop_event=self._stop_event,
        )

        worker.start()

@@ -11,10 +11,10 @@ from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
@@ -25,6 +25,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.enums import (
    NodeType,
    SystemVariableKey,
@@ -43,7 +44,6 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy

@@ -3,7 +3,7 @@ from typing import Any, Literal, Union

from pydantic import BaseModel

from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData


@@ -1,13 +1,13 @@
from collections.abc import Mapping, Sequence
from typing import Any

from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.variables import ArrayFileSegment, FileSegment, Segment


class AnswerNode(Node[AnswerNodeData]):

@@ -302,6 +302,10 @@ class Node(Generic[NodeDataT]):
        """
        raise NotImplementedError

    def _should_stop(self) -> bool:
        """Check if execution should be stopped."""
        return self.graph_runtime_state.stop_event.is_set()

    def run(self) -> Generator[GraphNodeEventBase, None, None]:
        execution_id = self.ensure_execution_id()
        self._start_at = naive_utc_now()
@@ -370,6 +374,21 @@ class Node(Generic[NodeDataT]):
                    yield event
                else:
                    yield event

                if self._should_stop():
                    error_message = "Execution cancelled"
                    yield NodeRunFailedEvent(
                        id=self.execution_id,
                        node_id=self._node_id,
                        node_type=self.node_type,
                        start_at=self._start_at,
                        node_run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.FAILED,
                            error=error_message,
                        ),
                        error=error_message,
                    )
                    return
        except Exception as e:
            logger.exception("Node %s failed to run", self._node_id)
            result = NodeRunResult(

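Node-level cancellation is polled, not preemptive: after forwarding each event, the node checks the shared stop event and, if set, emits a failed result with "Execution cancelled" and returns. A reduced sketch of that control flow (hypothetical string events standing in for the real event objects):

import threading
from collections.abc import Generator

def run_node(events: list[str], stop_event: threading.Event) -> Generator[str, None, None]:
    for event in events:
        yield event                 # forward the node's own event first
        if stop_event.is_set():     # poll between events, never mid-event
            yield "failed: Execution cancelled"
            return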
@@ -1,15 +1,17 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast

from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.variables.segments import ArrayFileSegment
from core.workflow.variables.types import SegmentType

from .exc import (
    CodeNodeError,
@@ -34,44 +36,12 @@ class WorkflowCodeExecutor(Protocol):
    def is_execution_error(self, error: Exception) -> bool: ...


def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]:
    return {
        "type": "code",
        "config": {
            "variables": [
                {"variable": "arg1", "value_selector": []},
                {"variable": "arg2", "value_selector": []},
            ],
            "code_language": language,
            "code": code,
            "outputs": {"result": {"type": "string", "children": None}},
        },
    }


_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = {
    CodeLanguage.PYTHON3: dedent(
        """
        def main(arg1: str, arg2: str):
            return {
                "result": arg1 + arg2,
            }
        """
    ),
    CodeLanguage.JAVASCRIPT: dedent(
        """
        function main({arg1, arg2}) {
            return {
                result: arg1 + arg2
            }
        }
        """
    ),
}


class CodeNode(Node[CodeNodeData]):
    node_type = NodeType.CODE
    _DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
        Python3CodeProvider,
        JavascriptCodeProvider,
    )
    _limits: CodeNodeLimits

    def __init__(
@@ -82,6 +52,7 @@ class CodeNode(Node[CodeNodeData]):
        graph_runtime_state: "GraphRuntimeState",
        *,
        code_executor: WorkflowCodeExecutor,
        code_providers: Sequence[type[CodeNodeProvider]] | None = None,
        code_limits: CodeNodeLimits,
    ) -> None:
        super().__init__(
@@ -91,6 +62,9 @@ class CodeNode(Node[CodeNodeData]):
            graph_runtime_state=graph_runtime_state,
        )
        self._code_executor: WorkflowCodeExecutor = code_executor
        self._code_providers: tuple[type[CodeNodeProvider], ...] = (
            tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
        )
        self._limits = code_limits

    @classmethod
@@ -104,10 +78,15 @@ class CodeNode(Node[CodeNodeData]):
        if filters:
            code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))

        default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language)
        if default_code is None:
            raise CodeNodeError(f"Unsupported code language: {code_language}")
        return _build_default_config(language=code_language, code=default_code)
        code_provider: type[CodeNodeProvider] = next(
            provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
        )

        return code_provider.get_default_config()

    @classmethod
    def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
        return cls._DEFAULT_CODE_PROVIDERS

    @classmethod
    def version(cls) -> str:
@@ -129,6 +108,7 @@ class CodeNode(Node[CodeNodeData]):
            variables[variable_name] = variable.to_object() if variable else None
        # Run code
        try:
            _ = self._select_code_provider(code_language)
            result = self._code_executor.execute(
                language=code_language,
                code=code,
@@ -150,6 +130,12 @@ class CodeNode(Node[CodeNodeData]):

        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)

    def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
        for provider in self._code_providers:
            if provider.is_accept_language(code_language):
                return provider
        raise CodeNodeError(f"Unsupported code language: {code_language}")

    def _check_string(self, value: str | None, variable: str) -> str | None:
        """
        Check string

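Language-to-provider resolution above is a first-match scan over an injectable tuple, with `CodeNodeError` as the miss case; `get_default_config` reuses the same class-level tuple. A stripped-down version of the lookup, with stub providers for illustration:

class PyProvider:
    @staticmethod
    def is_accept_language(lang: str) -> bool:
        return lang == "python3"

class JsProvider:
    @staticmethod
    def is_accept_language(lang: str) -> bool:
        return lang == "javascript"

def select_provider(providers: tuple[type, ...], lang: str) -> type:
    # First provider accepting the language wins, so tuple order is precedence.
    for provider in providers:
        if provider.is_accept_language(lang):
            return provider
    raise ValueError(f"Unsupported code language: {lang}")

assert select_provider((PyProvider, JsProvider), "javascript") is JsProvider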
@@ -1,18 +1,11 @@
from enum import StrEnum
from typing import Annotated, Literal

from pydantic import AfterValidator, BaseModel

from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.variables.types import SegmentType


class CodeLanguage(StrEnum):
    PYTHON3 = "python3"
    JINJA2 = "jinja2"
    JAVASCRIPT = "javascript"


_ALLOWED_OUTPUT_FROM_CODE = frozenset(
    [

@@ -21,12 +21,12 @@ from docx.table import Table
from docx.text.paragraph import Paragraph

from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.file import File, FileTransferMethod, file_manager
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables import ArrayFileSegment
from core.workflow.variables.segments import ArrayStringSegment, FileSegment

from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

@@ -10,9 +10,11 @@ from urllib.parse import urlencode, urlparse
import httpx
from json_repair import repair_json

from core.helper.ssrf_proxy import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.file.enums import FileTransferMethod
from core.workflow.file.file_manager import file_manager as default_file_manager
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayFileSegment, FileSegment

from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
@@ -79,8 +81,8 @@ class Executor:
        http_request_config: HttpRequestNodeConfig,
        max_retries: int | None = None,
        ssl_verify: bool | None = None,
        http_client: HttpClientProtocol,
        file_manager: FileManagerProtocol,
        http_client: HttpClientProtocol | None = None,
        file_manager: FileManagerProtocol | None = None,
    ):
        self._http_request_config = http_request_config
        # If authorization API key is present, convert the API key using the variable pool
@@ -114,8 +116,8 @@ class Executor:
        self.max_retries = (
            max_retries if max_retries is not None else self._http_request_config.ssrf_default_max_retries
        )
        self._http_client = http_client
        self._file_manager = file_manager
        self._http_client = http_client or ssrf_proxy
        self._file_manager = file_manager or default_file_manager

        # init template
        self.variable_pool = variable_pool

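Both the executor above and the node below move from required collaborators to optional ones backed by module-level defaults (`ssrf_proxy`, the shared `file_manager`), so production call sites need no wiring while tests can still inject fakes. The shape of the pattern, with hypothetical names:

class _SharedClient:
    def get(self, url: str) -> str:
        return f"GET {url}"

_default_client = _SharedClient()

class Executor:
    def __init__(self, client: _SharedClient | None = None) -> None:
        # None means "use the shared default"; tests pass a stub instead.
        self._client = client or _default_client

Executor()                        # production: falls back to the shared default
Executor(client=_SharedClient())  # test: explicit injection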
@@ -3,15 +3,18 @@ import mimetypes
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.file import File, FileTransferMethod
from core.workflow.file.file_manager import file_manager as default_file_manager
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol
from core.workflow.variables.segments import ArrayFileSegment
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory

from .config import build_http_request_config, resolve_http_request_config
@@ -42,9 +45,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
        graph_runtime_state: "GraphRuntimeState",
        *,
        http_request_config: HttpRequestNodeConfig,
        http_client: HttpClientProtocol,
        tool_file_manager_factory: Callable[[], ToolFileManagerProtocol],
        file_manager: FileManagerProtocol,
        http_client: HttpClientProtocol | None = None,
        tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
        file_manager: FileManagerProtocol | None = None,
    ) -> None:
        super().__init__(
            id=id,
@@ -52,11 +55,10 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )

        self._http_request_config = http_request_config
        self._http_client = http_client
        self._http_client = http_client or ssrf_proxy
        self._tool_file_manager_factory = tool_file_manager_factory
        self._file_manager = file_manager
        self._file_manager = file_manager or default_file_manager

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:

@@ -10,10 +10,10 @@ from typing import Annotated, Any, ClassVar, Literal, Self

from pydantic import BaseModel, Field, field_validator, model_validator

from core.variables.consts import SELECTORS_LENGTH
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables.consts import SELECTORS_LENGTH

from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit


@@ -7,6 +7,9 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from typing_extensions import TypeIs

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
    NodeExecutionType,
@@ -33,9 +36,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from core.workflow.variables import IntegerVariable, NoneSegment
from core.workflow.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.variables.variables import Variable
from libs.datetime_utils import naive_utc_now

from .exc import (

@@ -5,6 +5,12 @@ from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import (
    ArrayFileSegment,
    FileSegment,
    StringSegment,
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
    NodeType,
@@ -16,12 +22,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from core.workflow.variables import (
    ArrayFileSegment,
    FileSegment,
    StringSegment,
)
from core.workflow.variables.segments import ArrayObjectSegment

from .entities import KnowledgeRetrievalNodeData
from .exc import (

@@ -1,12 +1,12 @@
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, TypeVar

from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.workflow.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment

from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError

@@ -4,11 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator

from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
    ChatModelMessage,
    CompletionModelPromptTemplate,
    MemoryConfig,
)
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector


@@ -5,17 +5,21 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation
@@ -25,14 +29,46 @@ from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError


def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
    model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
        model_instance.model_name,
        model_instance.credentials,
def fetch_model_config(
    *,
    node_data_model: ModelConfig,
    credentials_provider: CredentialsProvider,
    model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    if not node_data_model.mode:
        raise LLMModeRequiredError("LLM mode is required.")

    credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
    model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
    provider_model_bundle = model_instance.provider_model_bundle

    provider_model = provider_model_bundle.configuration.get_provider_model(
        model=node_data_model.name,
        model_type=ModelType.LLM,
    )
    if provider_model is None:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
    provider_model.raise_for_status()

    stop: list[str] = []
    if "stop" in node_data_model.completion_params:
        stop = node_data_model.completion_params.pop("stop")

    model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
    if not model_schema:
        raise ValueError(f"Model schema not found for {model_instance.model_name}")
    return model_schema
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

    model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
    return model_instance, ModelConfigWithCredentialsEntity(
        provider=node_data_model.provider,
        model=node_data_model.name,
        model_schema=model_schema,
        mode=node_data_model.mode,
        provider_model_bundle=provider_model_bundle,
        credentials=credentials,
        parameters=node_data_model.completion_params,
        stop=stop,
    )


def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:

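`fetch_model_config` folds what used to be several call-site steps into one helper: it validates the node's model block, resolves credentials and the provider model, pops any `stop` list out of the completion params, and hands back both a ready `ModelInstance` and a `ModelConfigWithCredentialsEntity`. A hedged sketch of a call site (the two providers are the node's injected protocol implementations, named here only for illustration):

model_instance, model_config = fetch_model_config(
    node_data_model=node_data.model,            # provider / name / mode / completion_params
    credentials_provider=credentials_provider,  # implements CredentialsProvider
    model_factory=model_factory,                # implements ModelFactory
)
# model_config.stop now carries any "stop" list popped from completion_params.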
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Literal

from sqlalchemy import select

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@@ -36,12 +37,21 @@ from core.model_runtime.entities.message_entities import (
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
    ArrayFileSegment,
    ArraySegment,
    FileSegment,
    NoneSegment,
    ObjectSegment,
    StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
@@ -62,16 +72,8 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from core.workflow.variables import (
    ArrayFileSegment,
    ArraySegment,
    FileSegment,
    NoneSegment,
    ObjectSegment,
    StringSegment,
)
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
@@ -81,6 +83,7 @@ from .entities import (
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    ModelConfig,
)
from .exc import (
    InvalidContextStructureError,
@@ -113,8 +116,6 @@ class LLMNode(Node[LLMNodeData]):
    _llm_file_saver: LLMFileSaver
    _credentials_provider: CredentialsProvider
    _model_factory: ModelFactory
    _model_instance: ModelInstance
    _memory: PromptMessageMemory | None

    def __init__(
        self,
@@ -125,8 +126,6 @@ class LLMNode(Node[LLMNodeData]):
        *,
        credentials_provider: CredentialsProvider,
        model_factory: ModelFactory,
        model_instance: ModelInstance,
        memory: PromptMessageMemory | None = None,
        llm_file_saver: LLMFileSaver | None = None,
    ):
        super().__init__(
@@ -140,8 +139,6 @@ class LLMNode(Node[LLMNodeData]):

        self._credentials_provider = credentials_provider
        self._model_factory = model_factory
        self._model_instance = model_instance
        self._memory = memory

        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
@@ -205,12 +202,29 @@ class LLMNode(Node[LLMNodeData]):
            node_inputs["#context_files#"] = [file.model_dump() for file in context_files]

        # fetch model config
        model_instance = self._model_instance
        model_name = model_instance.model_name
        model_provider = model_instance.provider
        model_stop = model_instance.stop
        model_instance, model_config = self._fetch_model_config(
            node_data_model=self.node_data.model,
        )
        model_name = getattr(model_instance, "model_name", None)
        if not isinstance(model_name, str):
            model_name = model_config.model
        model_provider = getattr(model_instance, "provider", None)
        if not isinstance(model_provider, str):
            model_provider = model_config.provider
        model_schema = model_instance.model_type_instance.get_model_schema(
            model_name,
            model_instance.credentials,
        )
        if not model_schema:
            raise ValueError(f"Model schema not found for {model_name}")

        memory = self._memory
        # fetch memory
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
            app_id=self.app_id,
            node_data_memory=self.node_data.memory,
            model_instance=model_instance,
        )

        query: str | None = None
        if self.node_data.memory:
@@ -226,7 +240,9 @@ class LLMNode(Node[LLMNodeData]):
            context=context,
            memory=memory,
            model_instance=model_instance,
            stop=model_stop,
            model_schema=model_schema,
            model_parameters=self.node_data.model.completion_params,
            stop=model_config.stop,
            prompt_template=self.node_data.prompt_template,
            memory_config=self.node_data.memory,
            vision_enabled=self.node_data.vision.enabled,
@@ -238,6 +254,7 @@ class LLMNode(Node[LLMNodeData]):

        # handle invoke result
        generator = LLMNode.invoke_llm(
            node_data_model=self.node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop,
@@ -354,6 +371,7 @@ class LLMNode(Node[LLMNodeData]):
    @staticmethod
    def invoke_llm(
        *,
        node_data_model: ModelConfig,
        model_instance: ModelInstance,
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None = None,
@@ -366,10 +384,11 @@ class LLMNode(Node[LLMNodeData]):
        node_type: NodeType,
        reasoning_format: Literal["separated", "tagged"] = "tagged",
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
        model_parameters = model_instance.parameters
        invoke_model_parameters = dict(model_parameters)

        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
        model_schema = model_instance.model_type_instance.get_model_schema(
            node_data_model.name, model_instance.credentials
        )
        if not model_schema:
            raise ValueError(f"Model schema not found for {node_data_model.name}")

        if structured_output_enabled:
            output_schema = LLMNode.fetch_structured_output_schema(
@@ -383,7 +402,7 @@ class LLMNode(Node[LLMNodeData]):
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                json_schema=output_schema,
                model_parameters=invoke_model_parameters,
                model_parameters=node_data_model.completion_params,
                stop=list(stop or []),
                stream=True,
                user=user_id,
@@ -393,7 +412,7 @@ class LLMNode(Node[LLMNodeData]):

            invoke_result = model_instance.invoke_llm(
                prompt_messages=list(prompt_messages),
                model_parameters=invoke_model_parameters,
                model_parameters=node_data_model.completion_params,
                stop=list(stop or []),
                stream=True,
                user=user_id,
@@ -752,14 +771,33 @@ class LLMNode(Node[LLMNodeData]):

        return None

    def _fetch_model_config(
        self,
        *,
        node_data_model: ModelConfig,
    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        model, model_config_with_cred = llm_utils.fetch_model_config(
            node_data_model=node_data_model,
            credentials_provider=self._credentials_provider,
            model_factory=self._model_factory,
        )
        completion_params = model_config_with_cred.parameters

        model_config_with_cred.parameters = completion_params
        # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
        node_data_model.completion_params = completion_params
        return model, model_config_with_cred

    @staticmethod
    def fetch_prompt_messages(
        *,
        sys_query: str | None = None,
        sys_files: Sequence[File],
        context: str | None = None,
        memory: PromptMessageMemory | None = None,
        memory: TokenBufferMemory | None = None,
        model_instance: ModelInstance,
        model_schema: AIModelEntity,
        model_parameters: Mapping[str, Any],
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        stop: Sequence[str] | None = None,
        memory_config: MemoryConfig | None = None,
@@ -770,7 +808,6 @@ class LLMNode(Node[LLMNodeData]):
        context_files: list[File] | None = None,
    ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
        prompt_messages: list[PromptMessage] = []
        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

        if isinstance(prompt_template, list):
            # For chat model
@@ -789,6 +826,8 @@ class LLMNode(Node[LLMNodeData]):
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
                model_schema=model_schema,
                model_parameters=model_parameters,
            )
            # Extend prompt_messages with memory messages
            prompt_messages.extend(memory_messages)
@@ -826,6 +865,8 @@ class LLMNode(Node[LLMNodeData]):
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
                model_schema=model_schema,
                model_parameters=model_parameters,
            )
            # Insert histories into the prompt
            prompt_content = prompt_messages[0].content
@@ -1275,23 +1316,23 @@ def _calculate_rest_token(
    *,
    prompt_messages: list[PromptMessage],
    model_instance: ModelInstance,
    model_schema: AIModelEntity,
    model_parameters: Mapping[str, Any],
) -> int:
    rest_tokens = 2000
    runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
    runtime_model_parameters = model_instance.parameters

    model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens:
        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        max_tokens = 0
        for parameter_rule in runtime_model_schema.parameter_rules:
        for parameter_rule in model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or (
                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
            ):
                max_tokens = (
                    runtime_model_parameters.get(parameter_rule.name)
                    or runtime_model_parameters.get(str(parameter_rule.use_template))
                    model_parameters.get(parameter_rule.name)
                    or model_parameters.get(str(parameter_rule.use_template))
                    or 0
                )

@@ -1303,9 +1344,11 @@ def _calculate_rest_token(

def _handle_memory_chat_mode(
    *,
    memory: PromptMessageMemory | None,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
    model_schema: AIModelEntity,
    model_parameters: Mapping[str, Any],
) -> Sequence[PromptMessage]:
    memory_messages: Sequence[PromptMessage] = []
    # Get messages from memory for chat model
@@ -1313,6 +1356,8 @@ def _handle_memory_chat_mode(
        rest_tokens = _calculate_rest_token(
            prompt_messages=[],
            model_instance=model_instance,
            model_schema=model_schema,
            model_parameters=model_parameters,
        )
        memory_messages = memory.get_history_prompt_messages(
            max_token_limit=rest_tokens,
@@ -1323,9 +1368,11 @@ def _handle_memory_chat_mode(

def _handle_memory_completion_mode(
    *,
    memory: PromptMessageMemory | None,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
    model_schema: AIModelEntity,
    model_parameters: Mapping[str, Any],
) -> str:
    memory_text = ""
    # Get history text from memory for completion model
@@ -1333,51 +1380,20 @@ def _handle_memory_completion_mode(
        rest_tokens = _calculate_rest_token(
            prompt_messages=[],
            model_instance=model_instance,
            model_schema=model_schema,
            model_parameters=model_parameters,
        )
        if not memory_config.role_prefix:
            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
        memory_messages = memory.get_history_prompt_messages(
        memory_text = memory.get_history_prompt_text(
            max_token_limit=rest_tokens,
            message_limit=memory_config.window.size if memory_config.window.enabled else None,
        )
        memory_text = _convert_history_messages_to_text(
            history_messages=memory_messages,
            human_prefix=memory_config.role_prefix.user,
            ai_prefix=memory_config.role_prefix.assistant,
        )
    return memory_text


def _convert_history_messages_to_text(
|
||||
*,
|
||||
history_messages: Sequence[PromptMessage],
|
||||
human_prefix: str,
|
||||
ai_prefix: str,
|
||||
) -> str:
|
||||
string_messages: list[str] = []
|
||||
for message in history_messages:
|
||||
if message.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif message.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(message.content, list):
|
||||
content_parts = []
|
||||
for content in message.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
content_parts.append(content.data)
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
content_parts.append("[image]")
|
||||
|
||||
inner_msg = "\n".join(content_parts)
|
||||
string_messages.append(f"{role}: {inner_msg}")
|
||||
else:
|
||||
string_messages.append(f"{role}: {message.content}")
|
||||
return "\n".join(string_messages)

def _handle_completion_template(
    *,
    template: LLMNodeCompletionModelPromptTemplate,

@@ -1,10 +1,8 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Protocol

from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessage


class CredentialsProvider(Protocol):
@@ -21,13 +19,3 @@ class ModelFactory(Protocol):
    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
        """Create a model instance that is ready for schema lookup and invocation."""
        ...


class PromptMessageMemory(Protocol):
    """Port for loading memory as prompt messages for LLM nodes."""

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        """Return historical prompt messages constrained by token/message limits."""
        ...
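Because PromptMessageMemory is a typing.Protocol, any object exposing a structurally matching get_history_prompt_messages satisfies it without inheriting from it. A minimal sketch of a conforming test double, under the assumption that a production implementation would also honour the token limit (a plain list stands in for real PromptMessage instances):

from collections.abc import Sequence

class ListBackedMemory:
    """Structurally satisfies PromptMessageMemory for tests."""

    def __init__(self, messages: list) -> None:
        self._messages = messages

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: int | None = None
    ) -> Sequence:
        # Only the message-count window is applied here; a real
        # implementation would also trim to max_token_limit.
        if message_limit is not None:
            return self._messages[-message_limit:]
        return self._messages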
@@ -3,9 +3,9 @@ from typing import Annotated, Any, Literal

from pydantic import AfterValidator, BaseModel, Field, field_validator

from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.variables.types import SegmentType

_VALID_VAR_TYPE = frozenset(
    [

@@ -6,6 +6,7 @@ from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
    NodeExecutionType,
    NodeType,
@@ -30,7 +31,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from core.workflow.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now


@@ -7,10 +7,10 @@ from pydantic import (
    field_validator,
)

from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
from core.workflow.variables.types import SegmentType

_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"

@@ -1,6 +1,6 @@
from typing import Any

from core.workflow.variables.types import SegmentType
from core.variables.types import SegmentType


class ParameterExtractorNodeError(ValueError):

@@ -5,7 +5,7 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast

from core.memory.token_buffer_memory import TokenBufferMemory
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -19,19 +19,20 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from core.workflow.variables.types import ArrayValidation, SegmentType
from factories.variable_factory import build_segment_with_type

from .entities import ParameterExtractorNodeData
@@ -94,7 +95,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):

    node_type = NodeType.PARAMETER_EXTRACTOR

    _model_instance: ModelInstance
    _model_instance: ModelInstance | None = None
    _model_config: ModelConfigWithCredentialsEntity | None = None
    _credentials_provider: "CredentialsProvider"
    _model_factory: "ModelFactory"

@@ -107,7 +109,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        *,
        credentials_provider: "CredentialsProvider",
        model_factory: "ModelFactory",
        model_instance: ModelInstance,
    ) -> None:
        super().__init__(
            id=id,
@@ -117,7 +118,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        )
        self._credentials_provider = credentials_provider
        self._model_factory = model_factory
        self._model_instance = model_instance

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -155,14 +155,18 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            else []
        )

        model_instance = self._model_instance
        model_instance, model_config = self._fetch_model_config(node_data.model)
        if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
            raise InvalidModelTypeError("Model is not a Large Language Model")

        try:
            model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
        except ValueError as exc:
            raise ModelSchemaNotFoundError("Model schema not found") from exc
        llm_model = model_instance.model_type_instance
        model_schema = llm_model.get_model_schema(
            model=model_config.model,
            credentials=model_config.credentials,
        )
        if not model_schema:
            raise ModelSchemaNotFoundError("Model schema not found")

        # fetch memory
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
@@ -180,7 +184,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                node_data=node_data,
                query=query,
                variable_pool=self.graph_runtime_state.variable_pool,
                model_instance=model_instance,
                model_config=model_config,
                memory=memory,
                files=files,
                vision_detail=node_data.vision.configs.detail,
@@ -191,7 +195,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                data=node_data,
                query=query,
                variable_pool=self.graph_runtime_state.variable_pool,
                model_instance=model_instance,
                model_config=model_config,
                memory=memory,
                files=files,
                vision_detail=node_data.vision.configs.detail,
@@ -207,23 +211,24 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        }

        process_data = {
            "model_mode": node_data.model.mode,
            "model_mode": model_config.mode,
            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                model_mode=node_data.model.mode, prompt_messages=prompt_messages
                model_mode=model_config.mode, prompt_messages=prompt_messages
            ),
            "usage": None,
            "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
            "tool_call": None,
            "model_provider": model_instance.provider,
            "model_name": model_instance.model_name,
            "model_provider": model_config.provider,
            "model_name": model_config.model,
        }

        try:
            text, usage, tool_call = self._invoke(
                node_data_model=node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                tools=prompt_message_tools,
                stop=model_instance.stop,
                stop=model_config.stop,
            )
            process_data["usage"] = jsonable_encoder(usage)
            process_data["tool_call"] = jsonable_encoder(tool_call)
@@ -285,16 +290,17 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):

    def _invoke(
        self,
        node_data_model: ModelConfig,
        model_instance: ModelInstance,
        prompt_messages: list[PromptMessage],
        tools: list[PromptMessageTool],
        stop: Sequence[str],
        stop: list[str],
    ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=dict(model_instance.parameters),
            model_parameters=node_data_model.completion_params,
            tools=tools,
            stop=list(stop),
            stop=stop,
            stream=False,
            user=self.user_id,
        )
@@ -318,7 +324,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -331,13 +337,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        )

        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(
            node_data=node_data,
            query=query,
            variable_pool=variable_pool,
            model_instance=model_instance,
            context="",
        )
        rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
        prompt_template = self._get_function_calling_prompt_template(
            node_data, query, variable_pool, memory, rest_token
        )
@@ -349,7 +349,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            context="",
            memory_config=node_data.memory,
            memory=None,
            model_instance=model_instance,
            model_config=model_config,
            image_detail_config=vision_detail,
        )

@@ -406,7 +406,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -421,7 +421,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                node_data=data,
                query=query,
                variable_pool=variable_pool,
                model_instance=model_instance,
                model_config=model_config,
                memory=memory,
                files=files,
                vision_detail=vision_detail,
@@ -431,7 +431,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                node_data=data,
                query=query,
                variable_pool=variable_pool,
                model_instance=model_instance,
                model_config=model_config,
                memory=memory,
                files=files,
                vision_detail=vision_detail,
@@ -444,7 +444,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -454,11 +454,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(
            node_data=node_data,
            query=query,
            variable_pool=variable_pool,
            model_instance=model_instance,
            context="",
            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
        )
        prompt_template = self._get_prompt_engineering_prompt_template(
            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
@@ -471,7 +467,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            context="",
            memory_config=node_data.memory,
            memory=memory,
            model_instance=model_instance,
            model_config=model_config,
            image_detail_config=vision_detail,
        )

@@ -482,7 +478,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -492,11 +488,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(
            node_data=node_data,
            query=query,
            variable_pool=variable_pool,
            model_instance=model_instance,
            context="",
            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
        )
        prompt_template = self._get_prompt_engineering_prompt_template(
            node_data=node_data,
@@ -516,7 +508,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            context="",
            memory_config=node_data.memory,
            memory=None,
            model_instance=model_instance,
            model_config=model_config,
            image_detail_config=vision_detail,
        )

@@ -777,16 +769,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        context: str | None,
    ) -> int:
        try:
            model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
        except ValueError as exc:
            raise ModelSchemaNotFoundError("Model schema not found") from exc
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)

        if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
        model_instance, model_config = self._fetch_model_config(node_data.model)
        if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
            raise InvalidModelTypeError("Model is not a Large Language Model")

        llm_model = model_instance.model_type_instance
        model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
        if not model_schema:
            raise ModelSchemaNotFoundError("Model schema not found")

        if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
            prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
        else:
            prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
@@ -799,28 +796,27 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            context=context,
            memory_config=node_data.memory,
            memory=None,
            model_instance=model_instance,
            model_config=model_config,
        )
        rest_tokens = 2000

        model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
            model_type_instance = model_config.provider_model_bundle.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            curr_message_tokens = (
                model_type_instance.get_num_tokens(
                    model_instance.model_name, model_instance.credentials, prompt_messages
                )
                + 1000
                model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000
            )  # add 1000 to leave room for tool call messages

            max_tokens = 0
            for parameter_rule in model_schema.parameter_rules:
            for parameter_rule in model_config.model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_instance.parameters.get(parameter_rule.name)
                        or model_instance.parameters.get(parameter_rule.use_template or "")
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(parameter_rule.use_template or "")
                    ) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@@ -828,6 +824,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):

        return rest_tokens

    def _fetch_model_config(
        self, node_data_model: ModelConfig
    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config.
        """
        if not self._model_instance or not self._model_config:
            self._model_instance, self._model_config = llm_utils.fetch_model_config(
                node_data_model=node_data_model,
                credentials_provider=self._credentials_provider,
                model_factory=self._model_factory,
            )

        return self._model_instance, self._model_config
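_fetch_model_config above is a memoized accessor: the first call resolves the (instance, config) pair through llm_utils.fetch_model_config and caches it on the node; later calls reuse the cached pair. The pattern in isolation, as a hedged sketch with hypothetical names:

class LazyModelConfig:
    """Resolve an expensive pair once, then reuse it (same shape as above)."""

    def __init__(self, factory) -> None:
        self._factory = factory   # zero-arg callable returning (instance, config)
        self._instance = None
        self._config = None

    def get(self):
        # The original uses a truthiness guard (`if not ...`); this sketch
        # uses `is None`, which avoids refetching on falsy-but-valid values.
        if self._instance is None or self._config is None:
            self._instance, self._config = self._factory()
        return self._instance, self._config

pair = LazyModelConfig(lambda: ("instance", {"model": "example"}))
assert pair.get() == pair.get()   # second call hits the cache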

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,

@@ -27,16 +27,3 @@ class HttpClientProtocol(Protocol):

class FileManagerProtocol(Protocol):
    def download(self, f: File, /) -> bytes: ...


class ToolFileManagerProtocol(Protocol):
    def create_file_by_raw(
        self,
        *,
        user_id: str,
        tenant_id: str,
        conversation_id: str | None,
        file_binary: bytes,
        mimetype: str,
        filename: str | None = None,
    ) -> Any: ...
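Like the other Protocol classes in this diff, ToolFileManagerProtocol is satisfied structurally, which makes it easy to swap in a test double. A hedged sketch (the protocol's return type is Any, so the stub returns a plain dict; all names are illustrative):

from typing import Any

class InMemoryToolFileManager:
    """Test double that structurally satisfies ToolFileManagerProtocol."""

    def __init__(self) -> None:
        self.files: list[dict[str, Any]] = []

    def create_file_by_raw(
        self,
        *,
        user_id: str,
        tenant_id: str,
        conversation_id: str | None,
        file_binary: bytes,
        mimetype: str,
        filename: str | None = None,
    ) -> Any:
        record = {
            "user_id": user_id,
            "tenant_id": tenant_id,
            "conversation_id": conversation_id,
            "size": len(file_binary),
            "mimetype": mimetype,
            "filename": filename,
        }
        self.files.append(record)
        return record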

@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field

from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig


@@ -3,12 +3,14 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.memory.token_buffer_memory import TokenBufferMemory
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
    NodeExecutionType,
@@ -20,12 +22,7 @@ from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import (
    LLMNode,
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    llm_utils,
)
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -55,7 +52,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
    _llm_file_saver: LLMFileSaver
    _credentials_provider: "CredentialsProvider"
    _model_factory: "ModelFactory"
    _model_instance: ModelInstance

    def __init__(
        self,
@@ -66,7 +62,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
        *,
        credentials_provider: "CredentialsProvider",
        model_factory: "ModelFactory",
        model_instance: ModelInstance,
        llm_file_saver: LLMFileSaver | None = None,
    ):
        super().__init__(
@@ -80,7 +75,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):

        self._credentials_provider = credentials_provider
        self._model_factory = model_factory
        self._model_instance = model_instance

        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
@@ -101,8 +95,18 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
        variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
        query = variable.value if variable else None
        variables = {"query": query}
        # fetch model instance
        model_instance = self._model_instance
        # fetch model config
        model_instance, model_config = llm_utils.fetch_model_config(
            node_data_model=node_data.model,
            credentials_provider=self._credentials_provider,
            model_factory=self._model_factory,
        )
        model_schema = model_instance.model_type_instance.get_model_schema(
            model_instance.model_name,
            model_instance.credentials,
        )
        if not model_schema:
            raise ValueError(f"Model schema not found for {model_instance.model_name}")
        # fetch memory
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
@@ -127,7 +131,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
        rest_token = self._calculate_rest_token(
            node_data=node_data,
            query=query or "",
            model_instance=model_instance,
            model_config=model_config,
            context="",
        )
        prompt_template = self._get_prompt_template(
@@ -145,7 +149,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
            sys_query="",
            memory=memory,
            model_instance=model_instance,
            stop=model_instance.stop,
            model_schema=model_schema,
            model_parameters=node_data.model.completion_params,
            stop=model_config.stop,
            sys_files=files,
            vision_enabled=node_data.vision.enabled,
            vision_detail=node_data.vision.configs.detail,
@@ -160,6 +166,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
        try:
            # handle invoke result
            generator = LLMNode.invoke_llm(
                node_data_model=node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
@@ -198,14 +205,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
            category_name = classes_map[category_id_result]
            category_id = category_id_result
        process_data = {
            "model_mode": node_data.model.mode,
            "model_mode": model_config.mode,
            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                model_mode=node_data.model.mode, prompt_messages=prompt_messages
                model_mode=model_config.mode, prompt_messages=prompt_messages
            ),
            "usage": jsonable_encoder(usage),
            "finish_reason": finish_reason,
            "model_provider": model_instance.provider,
            "model_name": model_instance.model_name,
            "model_provider": model_config.provider,
            "model_name": model_config.model,
        }
        outputs = {
            "class_name": category_name,
@@ -278,40 +285,39 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
        self,
        node_data: QuestionClassifierNodeData,
        query: str,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        context: str | None,
    ) -> int:
        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
        prompt_messages, _ = LLMNode.fetch_prompt_messages(
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            sys_query="",
            sys_files=[],
            inputs={},
            query="",
            files=[],
            context=context,
            memory=None,
            model_instance=model_instance,
            stop=model_instance.stop,
            memory_config=node_data.memory,
            vision_enabled=False,
            vision_detail=node_data.vision.configs.detail,
            variable_pool=self.graph_runtime_state.variable_pool,
            jinja2_variables=[],
            memory=None,
            model_config=model_config,
        )
        rest_tokens = 2000

        model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )

            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

            max_tokens = 0
            for parameter_rule in model_schema.parameter_rules:
            for parameter_rule in model_config.model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_instance.parameters.get(parameter_rule.name)
                        or model_instance.parameters.get(parameter_rule.use_template or "")
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(parameter_rule.use_template or "")
                    ) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens

@@ -11,6 +11,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
    NodeType,
    SystemVariableKey,
@@ -21,8 +23,6 @@ from core.workflow.file import File, FileTransferMethod
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.workflow.variables.variables import ArrayAnyVariable
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile

@@ -2,14 +2,14 @@ import logging
from collections.abc import Mapping
from typing import Any

from core.variables.types import SegmentType
from core.variables.variables import FileVariable
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.file import FileTransferMethod
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables.types import SegmentType
from core.workflow.variables.variables import FileVariable
from factories import file_factory
from factories.variable_factory import build_segment_with_type


@@ -1,7 +1,7 @@
from pydantic import BaseModel

from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.types import SegmentType


class AdvancedSettings(BaseModel):

@@ -1,10 +1,10 @@
from collections.abc import Mapping

from core.variables.segments import Segment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorNodeData
from core.workflow.variables.segments import Segment


class VariableAggregatorNode(Node[VariableAggregatorNodeData]):

@@ -3,9 +3,9 @@ from typing import Any, TypeVar

from pydantic import BaseModel

from core.workflow.variables import Segment
from core.workflow.variables.consts import SELECTORS_LENGTH
from core.workflow.variables.types import SegmentType
from core.variables import Segment
from core.variables.consts import SELECTORS_LENGTH
from core.variables.types import SegmentType

# Use double underscore (`__`) prefix for internal variables
# to minimize risk of collision with user-defined variable names.
Some files were not shown because too many files have changed in this diff