refactor(graph_engine): Add a Config class for graph engine. (#31663)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-01-28 19:57:55 +08:00
committed by GitHub
parent 7f40f178ed
commit 24ebe2f5c6
17 changed files with 89 additions and 62 deletions

View File

@@ -105,7 +105,6 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability core.workflow.workflow_entry -> core.app.workflow.layers.observability
core.workflow.graph_engine.worker_management.worker_pool -> configs
core.workflow.nodes.agent.agent_node -> core.model_manager core.workflow.nodes.agent.agent_node -> core.model_manager
core.workflow.nodes.agent.agent_node -> core.provider_manager core.workflow.nodes.agent.agent_node -> core.provider_manager
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager core.workflow.nodes.agent.agent_node -> core.tools.tool_manager

View File

@@ -1,3 +1,4 @@
from .config import GraphEngineConfig
from .graph_engine import GraphEngine from .graph_engine import GraphEngine
__all__ = ["GraphEngine"] __all__ = ["GraphEngine", "GraphEngineConfig"]

View File

@@ -0,0 +1,14 @@
"""
GraphEngine configuration models.
"""
from pydantic import BaseModel
class GraphEngineConfig(BaseModel):
"""Configuration for GraphEngine worker pool scaling."""
min_workers: int = 1
max_workers: int = 5
scale_up_threshold: int = 3
scale_down_idle_time: float = 5.0

View File

@@ -37,6 +37,7 @@ from .command_processing import (
PauseCommandHandler, PauseCommandHandler,
UpdateVariablesCommandHandler, UpdateVariablesCommandHandler,
) )
from .config import GraphEngineConfig
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager from .event_management import EventHandler, EventManager
@@ -70,10 +71,7 @@ class GraphEngine:
graph: Graph, graph: Graph,
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
command_channel: CommandChannel, command_channel: CommandChannel,
min_workers: int | None = None, config: GraphEngineConfig,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None: ) -> None:
"""Initialize the graph engine with all subsystems and dependencies.""" """Initialize the graph engine with all subsystems and dependencies."""
# stop event # stop event
@@ -85,18 +83,12 @@ class GraphEngine:
self._graph_runtime_state.stop_event = self._stop_event self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel self._command_channel = command_channel
self._config = config
# Graph execution tracks the overall execution state # Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
self._graph_execution.workflow_id = workflow_id self._graph_execution.workflow_id = workflow_id
# === Worker Management Parameters ===
# Parameters for dynamic worker pool scaling
self._min_workers = min_workers
self._max_workers = max_workers
self._scale_up_threshold = scale_up_threshold
self._scale_down_idle_time = scale_down_idle_time
# === Execution Queues === # === Execution Queues ===
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue) self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
@@ -167,10 +159,7 @@ class GraphEngine:
graph=self._graph, graph=self._graph,
layers=self._layers, layers=self._layers,
execution_context=execution_context, execution_context=execution_context,
min_workers=self._min_workers, config=self._config,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
scale_down_idle_time=self._scale_down_idle_time,
stop_event=self._stop_event, stop_event=self._stop_event,
) )

View File

@@ -10,11 +10,11 @@ import queue
import threading import threading
from typing import final from typing import final
from configs import dify_config
from core.workflow.context import IExecutionContext from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase from core.workflow.graph_events import GraphNodeEventBase
from ..config import GraphEngineConfig
from ..layers.base import GraphEngineLayer from ..layers.base import GraphEngineLayer
from ..ready_queue import ReadyQueue from ..ready_queue import ReadyQueue
from ..worker import Worker from ..worker import Worker
@@ -38,11 +38,8 @@ class WorkerPool:
graph: Graph, graph: Graph,
layers: list[GraphEngineLayer], layers: list[GraphEngineLayer],
stop_event: threading.Event, stop_event: threading.Event,
config: GraphEngineConfig,
execution_context: IExecutionContext | None = None, execution_context: IExecutionContext | None = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None: ) -> None:
""" """
Initialize the simple worker pool. Initialize the simple worker pool.
@@ -52,23 +49,15 @@ class WorkerPool:
event_queue: Queue for worker events event_queue: Queue for worker events
graph: The workflow graph graph: The workflow graph
layers: Graph engine layers for node execution hooks layers: Graph engine layers for node execution hooks
config: GraphEngine worker pool configuration
execution_context: Optional execution context for context preservation execution_context: Optional execution context for context preservation
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Seconds before scaling down idle workers
""" """
self._ready_queue = ready_queue self._ready_queue = ready_queue
self._event_queue = event_queue self._event_queue = event_queue
self._graph = graph self._graph = graph
self._execution_context = execution_context self._execution_context = execution_context
self._layers = layers self._layers = layers
self._config = config
# Scaling parameters with defaults
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
# Worker management # Worker management
self._workers: list[Worker] = [] self._workers: list[Worker] = []
@@ -96,18 +85,18 @@ class WorkerPool:
if initial_count is None: if initial_count is None:
node_count = len(self._graph.nodes) node_count = len(self._graph.nodes)
if node_count < 10: if node_count < 10:
initial_count = self._min_workers initial_count = self._config.min_workers
elif node_count < 50: elif node_count < 50:
initial_count = min(self._min_workers + 1, self._max_workers) initial_count = min(self._config.min_workers + 1, self._config.max_workers)
else: else:
initial_count = min(self._min_workers + 2, self._max_workers) initial_count = min(self._config.min_workers + 2, self._config.max_workers)
logger.debug( logger.debug(
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)", "Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
initial_count, initial_count,
node_count, node_count,
self._min_workers, self._config.min_workers,
self._max_workers, self._config.max_workers,
) )
# Create initial workers # Create initial workers
@@ -176,7 +165,7 @@ class WorkerPool:
Returns: Returns:
True if scaled up, False otherwise True if scaled up, False otherwise
""" """
if queue_depth > self._scale_up_threshold and current_count < self._max_workers: if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers:
old_count = current_count old_count = current_count
self._create_worker() self._create_worker()
@@ -185,7 +174,7 @@ class WorkerPool:
old_count, old_count,
len(self._workers), len(self._workers),
queue_depth, queue_depth,
self._scale_up_threshold, self._config.scale_up_threshold,
) )
return True return True
return False return False
@@ -204,7 +193,7 @@ class WorkerPool:
True if scaled down, False otherwise True if scaled down, False otherwise
""" """
# Skip if we're at minimum or have no idle workers # Skip if we're at minimum or have no idle workers
if current_count <= self._min_workers or idle_count == 0: if current_count <= self._config.min_workers or idle_count == 0:
return False return False
# Check if we have excess capacity # Check if we have excess capacity
@@ -222,10 +211,10 @@ class WorkerPool:
for worker in self._workers: for worker in self._workers:
# Check if worker is idle and has exceeded idle time threshold # Check if worker is idle and has exceeded idle time threshold
if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time: if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time:
# Don't remove if it would leave us unable to handle the queue # Don't remove if it would leave us unable to handle the queue
remaining_workers = current_count - len(workers_to_remove) - 1 remaining_workers = current_count - len(workers_to_remove) - 1
if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2): if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2):
workers_to_remove.append((worker, worker.worker_id)) workers_to_remove.append((worker, worker.worker_id))
# Only remove one worker per check to avoid aggressive scaling # Only remove one worker per check to avoid aggressive scaling
break break
@@ -242,7 +231,7 @@ class WorkerPool:
old_count, old_count,
len(self._workers), len(self._workers),
len(workers_to_remove), len(workers_to_remove),
self._scale_down_idle_time, self._config.scale_down_idle_time,
queue_depth, queue_depth,
active_count, active_count,
idle_count - len(workers_to_remove), idle_count - len(workers_to_remove),
@@ -286,6 +275,6 @@ class WorkerPool:
return { return {
"total_workers": len(self._workers), "total_workers": len(self._workers),
"queue_depth": self._ready_queue.qsize(), "queue_depth": self._ready_queue.qsize(),
"min_workers": self._min_workers, "min_workers": self._config.min_workers,
"max_workers": self._max_workers, "max_workers": self._config.max_workers,
} }

View File

@@ -591,7 +591,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
from core.app.workflow.node_factory import DifyNodeFactory from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.runtime import GraphRuntimeState from core.workflow.runtime import GraphRuntimeState
@@ -640,6 +640,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
graph=iteration_graph, graph=iteration_graph,
graph_runtime_state=graph_runtime_state_copy, graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
) )
return graph_engine return graph_engine

View File

@@ -416,7 +416,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
from core.app.workflow.node_factory import DifyNodeFactory from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.runtime import GraphRuntimeState from core.workflow.runtime import GraphRuntimeState
@@ -452,6 +452,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
graph=loop_graph, graph=loop_graph,
graph_runtime_state=graph_runtime_state_copy, graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
) )
return graph_engine return graph_engine

View File

@@ -14,7 +14,7 @@ from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_engine.protocols.command_channel import CommandChannel
@@ -81,6 +81,12 @@ class WorkflowEntry:
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=command_channel, command_channel=command_channel,
config=GraphEngineConfig(
min_workers=dify_config.GRAPH_ENGINE_MIN_WORKERS,
max_workers=dify_config.GRAPH_ENGINE_MAX_WORKERS,
scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD,
scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME,
),
) )
# Add debug logging layer when in debug mode # Add debug logging layer when in debug mode

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import pytest import pytest
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers.base import ( from core.workflow.graph_engine.layers.base import (
GraphEngineLayer, GraphEngineLayer,
@@ -43,6 +43,7 @@ def test_layer_runtime_state_available_after_engine_layer() -> None:
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
layer = LayerForTest() layer = LayerForTest()

View File

@@ -8,7 +8,7 @@ from core.variables import IntegerVariable, StringVariable
from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import ( from core.workflow.graph_engine.entities.commands import (
AbortCommand, AbortCommand,
@@ -67,6 +67,7 @@ def test_abort_command():
graph=mock_graph, graph=mock_graph,
graph_runtime_state=shared_runtime_state, # Use shared instance graph_runtime_state=shared_runtime_state, # Use shared instance
command_channel=command_channel, command_channel=command_channel,
config=GraphEngineConfig(),
) )
# Send abort command before starting # Send abort command before starting
@@ -173,6 +174,7 @@ def test_pause_command():
graph=mock_graph, graph=mock_graph,
graph_runtime_state=shared_runtime_state, graph_runtime_state=shared_runtime_state,
command_channel=command_channel, command_channel=command_channel,
config=GraphEngineConfig(),
) )
pause_command = PauseCommand(reason="User requested pause") pause_command = PauseCommand(reason="User requested pause")
@@ -228,6 +230,7 @@ def test_update_variables_command_updates_pool():
graph=mock_graph, graph=mock_graph,
graph_runtime_state=shared_runtime_state, graph_runtime_state=shared_runtime_state,
command_channel=command_channel, command_channel=command_channel,
config=GraphEngineConfig(),
) )
update_command = UpdateVariablesCommand( update_command = UpdateVariablesCommand(

View File

@@ -7,7 +7,7 @@ This test validates that:
""" """
from core.workflow.enums import NodeType from core.workflow.enums import NodeType
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphRunSucceededEvent, GraphRunSucceededEvent,
@@ -44,6 +44,7 @@ def test_streaming_output_with_blocking_equals_one():
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Execute the workflow # Execute the workflow
@@ -139,6 +140,7 @@ def test_streaming_output_with_blocking_not_equals_one():
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Execute the workflow # Execute the workflow

View File

@@ -11,7 +11,7 @@ from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st from hypothesis import strategies as st
from core.workflow.enums import ErrorStrategy from core.workflow.enums import ErrorStrategy
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphRunPartialSucceededEvent, GraphRunPartialSucceededEvent,
@@ -469,6 +469,7 @@ def test_layer_system_basic():
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Add debug logging layer # Add debug logging layer
@@ -525,6 +526,7 @@ def test_layer_chaining():
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Chain multiple layers # Chain multiple layers
@@ -572,6 +574,7 @@ def test_layer_error_handling():
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Add faulty layer # Add faulty layer
@@ -753,6 +756,7 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
events = list(engine.run()) events = list(engine.run())

View File

@@ -566,7 +566,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
# Import dependencies # Import dependencies
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.runtime import GraphRuntimeState from core.workflow.runtime import GraphRuntimeState
@@ -623,6 +623,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
graph=iteration_graph, graph=iteration_graph,
graph_runtime_state=graph_runtime_state_copy, graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
) )
return graph_engine return graph_engine
@@ -641,7 +642,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
# Import dependencies # Import dependencies
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.runtime import GraphRuntimeState from core.workflow.runtime import GraphRuntimeState
@@ -685,6 +686,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
graph=loop_graph, graph=loop_graph,
graph_runtime_state=graph_runtime_state_copy, graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
) )
return graph_engine return graph_engine

View File

@@ -17,7 +17,7 @@ from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphRunSucceededEvent, GraphRunSucceededEvent,
@@ -123,6 +123,7 @@ def test_parallel_streaming_workflow():
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Define LLM outputs # Define LLM outputs

View File

@@ -12,7 +12,7 @@ from unittest.mock import MagicMock, Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphRunStartedEvent, GraphRunStartedEvent,
@@ -41,6 +41,7 @@ class TestStopEventPropagation:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Verify stop_event was created # Verify stop_event was created
@@ -84,6 +85,7 @@ class TestStopEventPropagation:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Set the stop_event before running # Set the stop_event before running
@@ -131,6 +133,7 @@ class TestStopEventPropagation:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Initially not set # Initially not set
@@ -155,6 +158,7 @@ class TestStopEventPropagation:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Verify WorkerPool has the stop_event # Verify WorkerPool has the stop_event
@@ -174,6 +178,7 @@ class TestStopEventPropagation:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Verify Dispatcher has the stop_event # Verify Dispatcher has the stop_event
@@ -311,6 +316,7 @@ class TestStopEventIntegration:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Set stop_event before running # Set stop_event before running
@@ -360,6 +366,7 @@ class TestStopEventIntegration:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# All nodes should share the same stop_event # All nodes should share the same stop_event
@@ -385,6 +392,7 @@ class TestStopEventTimeoutBehavior:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
dispatcher = engine._dispatcher dispatcher = engine._dispatcher
@@ -411,6 +419,7 @@ class TestStopEventTimeoutBehavior:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
worker_pool = engine._worker_pool worker_pool = engine._worker_pool
@@ -460,6 +469,7 @@ class TestStopEventResumeBehavior:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Simulate a previous execution that set stop_event # Simulate a previous execution that set stop_event
@@ -490,6 +500,7 @@ class TestWorkerStopBehavior:
graph=mock_graph, graph=mock_graph,
graph_runtime_state=runtime_state, graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
# Get the worker pool and check workers # Get the worker pool and check workers

View File

@@ -32,7 +32,7 @@ from core.variables import (
) )
from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphEngineEvent, GraphEngineEvent,
@@ -309,10 +309,12 @@ class TableTestRunner:
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
min_workers=self.graph_engine_min_workers, config=GraphEngineConfig(
max_workers=self.graph_engine_max_workers, min_workers=self.graph_engine_min_workers,
scale_up_threshold=self.graph_engine_scale_up_threshold, max_workers=self.graph_engine_max_workers,
scale_down_idle_time=self.graph_engine_scale_down_idle_time, scale_up_threshold=self.graph_engine_scale_up_threshold,
scale_down_idle_time=self.graph_engine_scale_down_idle_time,
),
) )
# Execute and collect events # Execute and collect events

View File

@@ -1,4 +1,4 @@
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphRunSucceededEvent, GraphRunSucceededEvent,
@@ -27,6 +27,7 @@ def test_tool_in_chatflow():
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(), command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
) )
events = list(engine.run()) events = list(engine.run())