mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 19:59:21 +08:00
refactor(graph_engine): Add a Config class for graph engine. (#31663)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -105,7 +105,6 @@ ignore_imports =
|
|||||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||||
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
||||||
core.workflow.graph_engine.worker_management.worker_pool -> configs
|
|
||||||
core.workflow.nodes.agent.agent_node -> core.model_manager
|
core.workflow.nodes.agent.agent_node -> core.model_manager
|
||||||
core.workflow.nodes.agent.agent_node -> core.provider_manager
|
core.workflow.nodes.agent.agent_node -> core.provider_manager
|
||||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from .config import GraphEngineConfig
|
||||||
from .graph_engine import GraphEngine
|
from .graph_engine import GraphEngine
|
||||||
|
|
||||||
__all__ = ["GraphEngine"]
|
__all__ = ["GraphEngine", "GraphEngineConfig"]
|
||||||
|
|||||||
14
api/core/workflow/graph_engine/config.py
Normal file
14
api/core/workflow/graph_engine/config.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
GraphEngine configuration models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GraphEngineConfig(BaseModel):
|
||||||
|
"""Configuration for GraphEngine worker pool scaling."""
|
||||||
|
|
||||||
|
min_workers: int = 1
|
||||||
|
max_workers: int = 5
|
||||||
|
scale_up_threshold: int = 3
|
||||||
|
scale_down_idle_time: float = 5.0
|
||||||
@@ -37,6 +37,7 @@ from .command_processing import (
|
|||||||
PauseCommandHandler,
|
PauseCommandHandler,
|
||||||
UpdateVariablesCommandHandler,
|
UpdateVariablesCommandHandler,
|
||||||
)
|
)
|
||||||
|
from .config import GraphEngineConfig
|
||||||
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
|
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
|
||||||
from .error_handler import ErrorHandler
|
from .error_handler import ErrorHandler
|
||||||
from .event_management import EventHandler, EventManager
|
from .event_management import EventHandler, EventManager
|
||||||
@@ -70,10 +71,7 @@ class GraphEngine:
|
|||||||
graph: Graph,
|
graph: Graph,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
command_channel: CommandChannel,
|
command_channel: CommandChannel,
|
||||||
min_workers: int | None = None,
|
config: GraphEngineConfig,
|
||||||
max_workers: int | None = None,
|
|
||||||
scale_up_threshold: int | None = None,
|
|
||||||
scale_down_idle_time: float | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||||
# stop event
|
# stop event
|
||||||
@@ -85,18 +83,12 @@ class GraphEngine:
|
|||||||
self._graph_runtime_state.stop_event = self._stop_event
|
self._graph_runtime_state.stop_event = self._stop_event
|
||||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||||
self._command_channel = command_channel
|
self._command_channel = command_channel
|
||||||
|
self._config = config
|
||||||
|
|
||||||
# Graph execution tracks the overall execution state
|
# Graph execution tracks the overall execution state
|
||||||
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
|
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
|
||||||
self._graph_execution.workflow_id = workflow_id
|
self._graph_execution.workflow_id = workflow_id
|
||||||
|
|
||||||
# === Worker Management Parameters ===
|
|
||||||
# Parameters for dynamic worker pool scaling
|
|
||||||
self._min_workers = min_workers
|
|
||||||
self._max_workers = max_workers
|
|
||||||
self._scale_up_threshold = scale_up_threshold
|
|
||||||
self._scale_down_idle_time = scale_down_idle_time
|
|
||||||
|
|
||||||
# === Execution Queues ===
|
# === Execution Queues ===
|
||||||
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
|
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
|
||||||
|
|
||||||
@@ -167,10 +159,7 @@ class GraphEngine:
|
|||||||
graph=self._graph,
|
graph=self._graph,
|
||||||
layers=self._layers,
|
layers=self._layers,
|
||||||
execution_context=execution_context,
|
execution_context=execution_context,
|
||||||
min_workers=self._min_workers,
|
config=self._config,
|
||||||
max_workers=self._max_workers,
|
|
||||||
scale_up_threshold=self._scale_up_threshold,
|
|
||||||
scale_down_idle_time=self._scale_down_idle_time,
|
|
||||||
stop_event=self._stop_event,
|
stop_event=self._stop_event,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -10,11 +10,11 @@ import queue
|
|||||||
import threading
|
import threading
|
||||||
from typing import final
|
from typing import final
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from core.workflow.context import IExecutionContext
|
from core.workflow.context import IExecutionContext
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_events import GraphNodeEventBase
|
from core.workflow.graph_events import GraphNodeEventBase
|
||||||
|
|
||||||
|
from ..config import GraphEngineConfig
|
||||||
from ..layers.base import GraphEngineLayer
|
from ..layers.base import GraphEngineLayer
|
||||||
from ..ready_queue import ReadyQueue
|
from ..ready_queue import ReadyQueue
|
||||||
from ..worker import Worker
|
from ..worker import Worker
|
||||||
@@ -38,11 +38,8 @@ class WorkerPool:
|
|||||||
graph: Graph,
|
graph: Graph,
|
||||||
layers: list[GraphEngineLayer],
|
layers: list[GraphEngineLayer],
|
||||||
stop_event: threading.Event,
|
stop_event: threading.Event,
|
||||||
|
config: GraphEngineConfig,
|
||||||
execution_context: IExecutionContext | None = None,
|
execution_context: IExecutionContext | None = None,
|
||||||
min_workers: int | None = None,
|
|
||||||
max_workers: int | None = None,
|
|
||||||
scale_up_threshold: int | None = None,
|
|
||||||
scale_down_idle_time: float | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the simple worker pool.
|
Initialize the simple worker pool.
|
||||||
@@ -52,23 +49,15 @@ class WorkerPool:
|
|||||||
event_queue: Queue for worker events
|
event_queue: Queue for worker events
|
||||||
graph: The workflow graph
|
graph: The workflow graph
|
||||||
layers: Graph engine layers for node execution hooks
|
layers: Graph engine layers for node execution hooks
|
||||||
|
config: GraphEngine worker pool configuration
|
||||||
execution_context: Optional execution context for context preservation
|
execution_context: Optional execution context for context preservation
|
||||||
min_workers: Minimum number of workers
|
|
||||||
max_workers: Maximum number of workers
|
|
||||||
scale_up_threshold: Queue depth to trigger scale up
|
|
||||||
scale_down_idle_time: Seconds before scaling down idle workers
|
|
||||||
"""
|
"""
|
||||||
self._ready_queue = ready_queue
|
self._ready_queue = ready_queue
|
||||||
self._event_queue = event_queue
|
self._event_queue = event_queue
|
||||||
self._graph = graph
|
self._graph = graph
|
||||||
self._execution_context = execution_context
|
self._execution_context = execution_context
|
||||||
self._layers = layers
|
self._layers = layers
|
||||||
|
self._config = config
|
||||||
# Scaling parameters with defaults
|
|
||||||
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
|
|
||||||
self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
|
|
||||||
self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
|
|
||||||
self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
|
||||||
|
|
||||||
# Worker management
|
# Worker management
|
||||||
self._workers: list[Worker] = []
|
self._workers: list[Worker] = []
|
||||||
@@ -96,18 +85,18 @@ class WorkerPool:
|
|||||||
if initial_count is None:
|
if initial_count is None:
|
||||||
node_count = len(self._graph.nodes)
|
node_count = len(self._graph.nodes)
|
||||||
if node_count < 10:
|
if node_count < 10:
|
||||||
initial_count = self._min_workers
|
initial_count = self._config.min_workers
|
||||||
elif node_count < 50:
|
elif node_count < 50:
|
||||||
initial_count = min(self._min_workers + 1, self._max_workers)
|
initial_count = min(self._config.min_workers + 1, self._config.max_workers)
|
||||||
else:
|
else:
|
||||||
initial_count = min(self._min_workers + 2, self._max_workers)
|
initial_count = min(self._config.min_workers + 2, self._config.max_workers)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
|
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
|
||||||
initial_count,
|
initial_count,
|
||||||
node_count,
|
node_count,
|
||||||
self._min_workers,
|
self._config.min_workers,
|
||||||
self._max_workers,
|
self._config.max_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create initial workers
|
# Create initial workers
|
||||||
@@ -176,7 +165,7 @@ class WorkerPool:
|
|||||||
Returns:
|
Returns:
|
||||||
True if scaled up, False otherwise
|
True if scaled up, False otherwise
|
||||||
"""
|
"""
|
||||||
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
|
if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers:
|
||||||
old_count = current_count
|
old_count = current_count
|
||||||
self._create_worker()
|
self._create_worker()
|
||||||
|
|
||||||
@@ -185,7 +174,7 @@ class WorkerPool:
|
|||||||
old_count,
|
old_count,
|
||||||
len(self._workers),
|
len(self._workers),
|
||||||
queue_depth,
|
queue_depth,
|
||||||
self._scale_up_threshold,
|
self._config.scale_up_threshold,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@@ -204,7 +193,7 @@ class WorkerPool:
|
|||||||
True if scaled down, False otherwise
|
True if scaled down, False otherwise
|
||||||
"""
|
"""
|
||||||
# Skip if we're at minimum or have no idle workers
|
# Skip if we're at minimum or have no idle workers
|
||||||
if current_count <= self._min_workers or idle_count == 0:
|
if current_count <= self._config.min_workers or idle_count == 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if we have excess capacity
|
# Check if we have excess capacity
|
||||||
@@ -222,10 +211,10 @@ class WorkerPool:
|
|||||||
|
|
||||||
for worker in self._workers:
|
for worker in self._workers:
|
||||||
# Check if worker is idle and has exceeded idle time threshold
|
# Check if worker is idle and has exceeded idle time threshold
|
||||||
if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time:
|
if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time:
|
||||||
# Don't remove if it would leave us unable to handle the queue
|
# Don't remove if it would leave us unable to handle the queue
|
||||||
remaining_workers = current_count - len(workers_to_remove) - 1
|
remaining_workers = current_count - len(workers_to_remove) - 1
|
||||||
if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
|
if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2):
|
||||||
workers_to_remove.append((worker, worker.worker_id))
|
workers_to_remove.append((worker, worker.worker_id))
|
||||||
# Only remove one worker per check to avoid aggressive scaling
|
# Only remove one worker per check to avoid aggressive scaling
|
||||||
break
|
break
|
||||||
@@ -242,7 +231,7 @@ class WorkerPool:
|
|||||||
old_count,
|
old_count,
|
||||||
len(self._workers),
|
len(self._workers),
|
||||||
len(workers_to_remove),
|
len(workers_to_remove),
|
||||||
self._scale_down_idle_time,
|
self._config.scale_down_idle_time,
|
||||||
queue_depth,
|
queue_depth,
|
||||||
active_count,
|
active_count,
|
||||||
idle_count - len(workers_to_remove),
|
idle_count - len(workers_to_remove),
|
||||||
@@ -286,6 +275,6 @@ class WorkerPool:
|
|||||||
return {
|
return {
|
||||||
"total_workers": len(self._workers),
|
"total_workers": len(self._workers),
|
||||||
"queue_depth": self._ready_queue.qsize(),
|
"queue_depth": self._ready_queue.qsize(),
|
||||||
"min_workers": self._min_workers,
|
"min_workers": self._config.min_workers,
|
||||||
"max_workers": self._max_workers,
|
"max_workers": self._config.max_workers,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -591,7 +591,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||||||
from core.app.workflow.node_factory import DifyNodeFactory
|
from core.app.workflow.node_factory import DifyNodeFactory
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.runtime import GraphRuntimeState
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
|
|
||||||
@@ -640,6 +640,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||||||
graph=iteration_graph,
|
graph=iteration_graph,
|
||||||
graph_runtime_state=graph_runtime_state_copy,
|
graph_runtime_state=graph_runtime_state_copy,
|
||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph_engine
|
return graph_engine
|
||||||
|
|||||||
@@ -416,7 +416,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
from core.app.workflow.node_factory import DifyNodeFactory
|
from core.app.workflow.node_factory import DifyNodeFactory
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.runtime import GraphRuntimeState
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
|
|
||||||
@@ -452,6 +452,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
graph=loop_graph,
|
graph=loop_graph,
|
||||||
graph_runtime_state=graph_runtime_state_copy,
|
graph_runtime_state=graph_runtime_state_copy,
|
||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph_engine
|
return graph_engine
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
|||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
|
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
|
||||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||||
@@ -81,6 +81,12 @@ class WorkflowEntry:
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=command_channel,
|
command_channel=command_channel,
|
||||||
|
config=GraphEngineConfig(
|
||||||
|
min_workers=dify_config.GRAPH_ENGINE_MIN_WORKERS,
|
||||||
|
max_workers=dify_config.GRAPH_ENGINE_MAX_WORKERS,
|
||||||
|
scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD,
|
||||||
|
scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add debug logging layer when in debug mode
|
# Add debug logging layer when in debug mode
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_engine.layers.base import (
|
from core.workflow.graph_engine.layers.base import (
|
||||||
GraphEngineLayer,
|
GraphEngineLayer,
|
||||||
@@ -43,6 +43,7 @@ def test_layer_runtime_state_available_after_engine_layer() -> None:
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
layer = LayerForTest()
|
layer = LayerForTest()
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from core.variables import IntegerVariable, StringVariable
|
|||||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.entities.pause_reason import SchedulingPause
|
from core.workflow.entities.pause_reason import SchedulingPause
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_engine.entities.commands import (
|
from core.workflow.graph_engine.entities.commands import (
|
||||||
AbortCommand,
|
AbortCommand,
|
||||||
@@ -67,6 +67,7 @@ def test_abort_command():
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=shared_runtime_state, # Use shared instance
|
graph_runtime_state=shared_runtime_state, # Use shared instance
|
||||||
command_channel=command_channel,
|
command_channel=command_channel,
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send abort command before starting
|
# Send abort command before starting
|
||||||
@@ -173,6 +174,7 @@ def test_pause_command():
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=shared_runtime_state,
|
graph_runtime_state=shared_runtime_state,
|
||||||
command_channel=command_channel,
|
command_channel=command_channel,
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
pause_command = PauseCommand(reason="User requested pause")
|
pause_command = PauseCommand(reason="User requested pause")
|
||||||
@@ -228,6 +230,7 @@ def test_update_variables_command_updates_pool():
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=shared_runtime_state,
|
graph_runtime_state=shared_runtime_state,
|
||||||
command_channel=command_channel,
|
command_channel=command_channel,
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
update_command = UpdateVariablesCommand(
|
update_command = UpdateVariablesCommand(
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ This test validates that:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from core.workflow.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
@@ -44,6 +44,7 @@ def test_streaming_output_with_blocking_equals_one():
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute the workflow
|
# Execute the workflow
|
||||||
@@ -139,6 +140,7 @@ def test_streaming_output_with_blocking_not_equals_one():
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute the workflow
|
# Execute the workflow
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from hypothesis import HealthCheck, given, settings
|
|||||||
from hypothesis import strategies as st
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy
|
from core.workflow.enums import ErrorStrategy
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphRunPartialSucceededEvent,
|
GraphRunPartialSucceededEvent,
|
||||||
@@ -469,6 +469,7 @@ def test_layer_system_basic():
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add debug logging layer
|
# Add debug logging layer
|
||||||
@@ -525,6 +526,7 @@ def test_layer_chaining():
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Chain multiple layers
|
# Chain multiple layers
|
||||||
@@ -572,6 +574,7 @@ def test_layer_error_handling():
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add faulty layer
|
# Add faulty layer
|
||||||
@@ -753,6 +756,7 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
events = list(engine.run())
|
events = list(engine.run())
|
||||||
|
|||||||
@@ -566,7 +566,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
|
|||||||
# Import dependencies
|
# Import dependencies
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.runtime import GraphRuntimeState
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
|
|
||||||
@@ -623,6 +623,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
|
|||||||
graph=iteration_graph,
|
graph=iteration_graph,
|
||||||
graph_runtime_state=graph_runtime_state_copy,
|
graph_runtime_state=graph_runtime_state_copy,
|
||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph_engine
|
return graph_engine
|
||||||
@@ -641,7 +642,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
|
|||||||
# Import dependencies
|
# Import dependencies
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.runtime import GraphRuntimeState
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
|
|
||||||
@@ -685,6 +686,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
|
|||||||
graph=loop_graph,
|
graph=loop_graph,
|
||||||
graph_runtime_state=graph_runtime_state_copy,
|
graph_runtime_state=graph_runtime_state_copy,
|
||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph_engine
|
return graph_engine
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from core.app.workflow.node_factory import DifyNodeFactory
|
|||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
@@ -123,6 +123,7 @@ def test_parallel_streaming_workflow():
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Define LLM outputs
|
# Define LLM outputs
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from unittest.mock import MagicMock, Mock, patch
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphRunStartedEvent,
|
GraphRunStartedEvent,
|
||||||
@@ -41,6 +41,7 @@ class TestStopEventPropagation:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify stop_event was created
|
# Verify stop_event was created
|
||||||
@@ -84,6 +85,7 @@ class TestStopEventPropagation:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set the stop_event before running
|
# Set the stop_event before running
|
||||||
@@ -131,6 +133,7 @@ class TestStopEventPropagation:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initially not set
|
# Initially not set
|
||||||
@@ -155,6 +158,7 @@ class TestStopEventPropagation:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify WorkerPool has the stop_event
|
# Verify WorkerPool has the stop_event
|
||||||
@@ -174,6 +178,7 @@ class TestStopEventPropagation:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify Dispatcher has the stop_event
|
# Verify Dispatcher has the stop_event
|
||||||
@@ -311,6 +316,7 @@ class TestStopEventIntegration:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set stop_event before running
|
# Set stop_event before running
|
||||||
@@ -360,6 +366,7 @@ class TestStopEventIntegration:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# All nodes should share the same stop_event
|
# All nodes should share the same stop_event
|
||||||
@@ -385,6 +392,7 @@ class TestStopEventTimeoutBehavior:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
dispatcher = engine._dispatcher
|
dispatcher = engine._dispatcher
|
||||||
@@ -411,6 +419,7 @@ class TestStopEventTimeoutBehavior:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
worker_pool = engine._worker_pool
|
worker_pool = engine._worker_pool
|
||||||
@@ -460,6 +469,7 @@ class TestStopEventResumeBehavior:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Simulate a previous execution that set stop_event
|
# Simulate a previous execution that set stop_event
|
||||||
@@ -490,6 +500,7 @@ class TestWorkerStopBehavior:
|
|||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_runtime_state=runtime_state,
|
graph_runtime_state=runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the worker pool and check workers
|
# Get the worker pool and check workers
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from core.variables import (
|
|||||||
)
|
)
|
||||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
@@ -309,10 +309,12 @@ class TableTestRunner:
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
min_workers=self.graph_engine_min_workers,
|
config=GraphEngineConfig(
|
||||||
max_workers=self.graph_engine_max_workers,
|
min_workers=self.graph_engine_min_workers,
|
||||||
scale_up_threshold=self.graph_engine_scale_up_threshold,
|
max_workers=self.graph_engine_max_workers,
|
||||||
scale_down_idle_time=self.graph_engine_scale_down_idle_time,
|
scale_up_threshold=self.graph_engine_scale_up_threshold,
|
||||||
|
scale_down_idle_time=self.graph_engine_scale_down_idle_time,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute and collect events
|
# Execute and collect events
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
@@ -27,6 +27,7 @@ def test_tool_in_chatflow():
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
events = list(engine.run())
|
events = list(engine.run())
|
||||||
|
|||||||
Reference in New Issue
Block a user